InstanceNorm{1,2,3}d (#28790)

Summary:
Hi yf225,

I have a few doubts related to implementation:
1) What tests do I have to write?
2) What does _load_state_from_dict do?
3) Do I need to override the reset() function, as I cannot see its utility?
4) InstanceNormOptions could perhaps be replaced with BatchNormOptions, but I find that
`track_running_stats` is not defined there; instead, `stateful` is defined.

InstanceNorm{1,2,3}d https://github.com/pytorch/pytorch/issues/25883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28790

Differential Revision: D18588666

Pulled By: yf225

fbshipit-source-id: bb9b81f01f62c3fc8765fa0ba0716768087ee155
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index ba36dd0..c05af39 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -566,6 +566,7 @@
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
+      ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/distance.cpp
@@ -583,11 +584,11 @@
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
+      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
-      ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
       ${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index d4aca98..981c155 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1579,6 +1579,236 @@
   ASSERT_TRUE(output.allclose(expected));
 }
 
+TEST_F(FunctionalTest, InstanceNorm1d) {
+  int num_features = 5;
+  double eps = 1e-05;
+  double momentum = 0.1;
+
+  auto input = torch::arange(40.).view({2, 5, 4});
+  auto mean = torch::arange(5.);
+  auto variance = torch::arange(5.);
+  auto weight = torch::arange((double)num_features);
+  auto bias = torch::arange((double)num_features);
+  auto output = F::instance_norm(
+    input,
+    F::InstanceNormFuncOptions()
+      .running_mean(mean)
+      .running_var(variance)
+      .weight(weight)
+      .bias(bias)
+      .momentum(momentum)
+      .eps(eps));
+  auto expected = torch::tensor({{{ 0.0000,  0.0000,  0.0000,  0.0000},
+                                  {-0.3416,  0.5528,  1.4472,  2.3416},
+                                  {-0.6833,  1.1056,  2.8944,  4.6833},
+                                  {-1.0249,  1.6584,  4.3416,  7.0249},
+                                  {-1.3665,  2.2112,  5.7888,  9.3665}},
+                                 {{ 0.0000,  0.0000,  0.0000,  0.0000},
+                                  {-0.3416,  0.5528,  1.4472,  2.3416},
+                                  {-0.6833,  1.1056,  2.8944,  4.6833},
+                                  {-1.0249,  1.6584,  4.3416,  7.0249},
+                                  {-1.3665,  2.2112,  5.7888,  9.3665}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
+  auto input = torch::arange(40.).view({2, 5, 4});
+  auto output = F::instance_norm(input);
+  auto expected = torch::tensor({{{-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416}},
+                                 {{-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416},
+                                  {-1.3416, -0.4472,  0.4472,  1.3416}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm2d) {
+  int num_features = 5;
+  double eps = 1e-05;
+  double momentum = 0.1;
+
+  auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
+  auto mean = torch::arange((double)num_features);
+  auto variance = torch::arange((double)num_features);
+  auto weight = torch::arange((double)num_features);
+  auto bias = torch::arange((double)num_features);
+  auto output = F::instance_norm(
+    input,
+    F::InstanceNormFuncOptions()
+      .running_mean(mean)
+      .running_var(variance)
+      .weight(weight)
+      .bias(bias)
+      .momentum(momentum)
+      .eps(eps));
+  auto expected = torch::tensor({{{{ 0.0000,  0.0000},
+                                   { 0.0000,  0.0000}},
+                                  {{-0.3416,  0.5528},
+                                   { 1.4472,  2.3416}},
+                                  {{-0.6833,  1.1056},
+                                   { 2.8944,  4.6833}},
+                                  {{-1.0249,  1.6584},
+                                   { 4.3416,  7.0249}},
+                                  {{-1.3665,  2.2112},
+                                   { 5.7888,  9.3665}}},
+                                 {{{ 0.0000,  0.0000},
+                                   { 0.0000,  0.0000}},
+                                  {{-0.3416,  0.5528},
+                                   { 1.4472,  2.3416}},
+                                  {{-0.6833,  1.1056},
+                                   { 2.8944,  4.6833}},
+                                  {{-1.0249,  1.6584},
+                                   { 4.3416,  7.0249}},
+                                  {{-1.3665,  2.2112},
+                                   { 5.7888,  9.3665}}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
+  int num_features = 5;
+  double eps = 1e-05;
+
+  auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
+  auto output = F::instance_norm(input);
+  auto expected = torch::tensor({{{{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}}},
+                                 {{{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm3d) {
+  int num_features = 5;
+  double eps = 1e-05;
+  double momentum = 0.1;
+
+  auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
+  auto mean = torch::arange((double)num_features);
+  auto variance = torch::arange((double)num_features);
+  auto weight = torch::arange((double)num_features);
+  auto bias = torch::arange((double)num_features);
+  auto output = F::instance_norm(
+    input,
+    F::InstanceNormFuncOptions()
+      .running_mean(mean)
+      .running_var(variance)
+      .weight(weight)
+      .bias(bias)
+      .momentum(momentum)
+      .eps(eps));
+  auto expected = torch::tensor({{{{{ 0.0000,  0.0000},
+                                    { 0.0000,  0.0000}},
+                                   {{ 0.0000,  0.0000},
+                                    { 0.0000,  0.0000}}},
+                                  {{{-0.5275, -0.0911},
+                                    { 0.3453,  0.7818}},
+                                   {{ 1.2182,  1.6547},
+                                    { 2.0911,  2.5275}}},
+                                  {{{-1.0550, -0.1822},
+                                    { 0.6907,  1.5636}},
+                                   {{ 2.4364,  3.3093},
+                                    { 4.1822,  5.0550}}},
+                                  {{{-1.5826, -0.2733},
+                                    { 1.0360,  2.3453}},
+                                   {{ 3.6547,  4.9640},
+                                    { 6.2733,  7.5826}}},
+                                  {{{-2.1101, -0.3644},
+                                    { 1.3814,  3.1271}},
+                                   {{ 4.8729,  6.6186},
+                                    { 8.3644, 10.1101}}}},
+                                 {{{{ 0.0000,  0.0000},
+                                    { 0.0000,  0.0000}},
+                                   {{ 0.0000,  0.0000},
+                                    { 0.0000,  0.0000}}},
+                                  {{{-0.5275, -0.0911},
+                                    { 0.3453,  0.7818}},
+                                   {{ 1.2182,  1.6547},
+                                    { 2.0911,  2.5275}}},
+                                  {{{-1.0550, -0.1822},
+                                    { 0.6907,  1.5636}},
+                                   {{ 2.4364,  3.3093},
+                                    { 4.1822,  5.0550}}},
+                                  {{{-1.5826, -0.2733},
+                                    { 1.0360,  2.3453}},
+                                   {{ 3.6547,  4.9640},
+                                    { 6.2733,  7.5826}}},
+                                  {{{-2.1101, -0.3644},
+                                    { 1.3814,  3.1271}},
+                                   {{ 4.8729,  6.6186},
+                                    { 8.3644, 10.1101}}}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
+  int num_features = 5;
+  double eps = 1e-05;
+
+  auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
+  auto output = F::instance_norm(input);
+  auto expected = torch::tensor({{{{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}}},
+                                 {{{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}}}});
+  ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
 TEST_F(FunctionalTest, Interpolate) {
   {
     // 1D interpolation
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 96a1b0e..33fdb90 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1427,7 +1427,7 @@
 }
 
 TEST_F(ModulesTest, BatchNorm1dStateful) {
-  BatchNorm1d bn(BatchNorm1dOptions(5));
+  BatchNorm1d bn(5);
 
   ASSERT_TRUE(bn->options.track_running_stats());
 
@@ -1464,20 +1464,30 @@
 }
 
 TEST_F(ModulesTest, BatchNorm1d) {
-  BatchNorm1d bn(BatchNorm1dOptions(5));
+  BatchNorm1d bn(5);
   bn->eval();
 
-  auto input = torch::randn({2, 5}, torch::requires_grad());
+  auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
   auto output = bn->forward(input);
+  auto expected = torch::tensor({{{ 0.0000,  1.0000},
+                                  { 2.0000,  3.0000},
+                                  { 4.0000,  5.0000},
+                                  { 6.0000,  7.0000},
+                                  { 8.0000,  9.0000}},
+                                 {{10.0000, 10.9999},
+                                  {11.9999, 12.9999},
+                                  {13.9999, 14.9999},
+                                  {15.9999, 16.9999},
+                                  {17.9999, 18.9999}}});
+  ASSERT_TRUE(output.allclose(expected));
   auto s = output.sum();
   s.backward();
 
   ASSERT_EQ(input.sizes(), input.grad().sizes());
-  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
 }
 
 TEST_F(ModulesTest, BatchNorm2dStateful) {
-  BatchNorm2d bn(BatchNorm2dOptions(5));
+  BatchNorm2d bn(5);
 
   ASSERT_TRUE(bn->options.track_running_stats());
 
@@ -1514,20 +1524,40 @@
 }
 
 TEST_F(ModulesTest, BatchNorm2d) {
-  BatchNorm2d bn(BatchNorm2dOptions(5));
+  BatchNorm2d bn(5);
   bn->eval();
 
-  auto input = torch::randn({2, 5, 4, 4}, torch::requires_grad());
+  auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
   auto output = bn->forward(input);
+  auto expected = torch::tensor({{{{ 0.0000,  1.0000},
+                                   { 2.0000,  3.0000}},
+                                  {{ 4.0000,  5.0000},
+                                   { 6.0000,  7.0000}},
+                                  {{ 8.0000,  9.0000},
+                                   {10.0000, 10.9999}},
+                                  {{11.9999, 12.9999},
+                                   {13.9999, 14.9999}},
+                                  {{15.9999, 16.9999},
+                                   {17.9999, 18.9999}}},
+                                 {{{19.9999, 20.9999},
+                                   {21.9999, 22.9999}},
+                                  {{23.9999, 24.9999},
+                                   {25.9999, 26.9999}},
+                                  {{27.9999, 28.9999},
+                                   {29.9998, 30.9998}},
+                                  {{31.9998, 32.9998},
+                                   {33.9998, 34.9998}},
+                                  {{35.9998, 36.9998},
+                                   {37.9998, 38.9998}}}});
+  ASSERT_TRUE(output.allclose(expected));
   auto s = output.sum();
   s.backward();
   
   ASSERT_EQ(input.sizes(), input.grad().sizes());
-  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4})));
 }
 
 TEST_F(ModulesTest, BatchNorm3dStateful) {
-  BatchNorm3d bn(BatchNorm3dOptions(5));
+  BatchNorm3d bn(5);
 
   ASSERT_TRUE(bn->options.track_running_stats());
 
@@ -1564,16 +1594,276 @@
 }
 
 TEST_F(ModulesTest, BatchNorm3d) {
-  BatchNorm3d bn(BatchNorm3dOptions(5));
+  BatchNorm3d bn(5);
   bn->eval();
 
-  auto input = torch::randn({2, 5, 4, 4, 4}, torch::requires_grad());
+  auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
   auto output = bn->forward(input);
+  auto expected = torch::tensor({{{{{ 0.0000,  1.0000},
+                                    { 2.0000,  3.0000}},
+                                   {{ 4.0000,  5.0000},
+                                    { 6.0000,  7.0000}}},
+                                  {{{ 8.0000,  9.0000},
+                                    {10.0000, 10.9999}},
+                                   {{11.9999, 12.9999},
+                                    {13.9999, 14.9999}}},
+                                  {{{15.9999, 16.9999},
+                                    {17.9999, 18.9999}},
+                                   {{19.9999, 20.9999},
+                                    {21.9999, 22.9999}}},
+                                  {{{23.9999, 24.9999},
+                                    {25.9999, 26.9999}},
+                                   {{27.9999, 28.9999},
+                                    {29.9998, 30.9998}}},
+                                  {{{31.9998, 32.9998},
+                                    {33.9998, 34.9998}},
+                                   {{35.9998, 36.9998},
+                                    {37.9998, 38.9998}}}},
+                                 {{{{39.9998, 40.9998},
+                                    {41.9998, 42.9998}},
+                                   {{43.9998, 44.9998},
+                                    {45.9998, 46.9998}}},
+                                  {{{47.9998, 48.9998},
+                                    {49.9997, 50.9997}},
+                                   {{51.9997, 52.9997},
+                                    {53.9997, 54.9997}}},
+                                  {{{55.9997, 56.9997},
+                                    {57.9997, 58.9997}},
+                                   {{59.9997, 60.9997},
+                                    {61.9997, 62.9997}}},
+                                  {{{63.9997, 64.9997},
+                                    {65.9997, 66.9997}},
+                                   {{67.9997, 68.9997},
+                                    {69.9996, 70.9996}}},
+                                  {{{71.9996, 72.9996},
+                                    {73.9996, 74.9996}},
+                                   {{75.9996, 76.9996},
+                                    {77.9996, 78.9996}}}}});
+  ASSERT_TRUE(output.allclose(expected));
   auto s = output.sum();
   s.backward();
   
   ASSERT_EQ(input.sizes(), input.grad().sizes());
-  ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4, 4})));
+}
+
+TEST_F(ModulesTest, InstanceNorm1dStateful) {
+  InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
+
+  ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+  ASSERT_TRUE(instance_norm->running_mean.defined());
+  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+  ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->running_var.defined());
+  ASSERT_EQ(instance_norm->running_var.dim(), 1);
+  ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+  ASSERT_TRUE(instance_norm->options.affine());
+
+  ASSERT_TRUE(instance_norm->weight.defined());
+  ASSERT_EQ(instance_norm->weight.dim(), 1);
+  ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->bias.defined());
+  ASSERT_EQ(instance_norm->bias.dim(), 1);
+  ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm1dStateless) {
+  InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
+
+  ASSERT_FALSE(instance_norm->running_mean.defined());
+  ASSERT_FALSE(instance_norm->running_var.defined());
+  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+  ASSERT_FALSE(instance_norm->weight.defined());
+  ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm1d) {
+  InstanceNorm1d instance_norm(5);
+  instance_norm->eval();
+
+  auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
+  auto output = instance_norm->forward(input);
+  auto expected = torch::tensor({{{-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000}},
+                                 {{-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000},
+                                  {-1.0000, 1.0000}}});
+  ASSERT_TRUE(output.allclose(expected, 1e-3));
+  auto s = output.sum();
+  s.backward();
+
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
+}
+
+TEST_F(ModulesTest, InstanceNorm2dStateful) {
+  InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
+
+  ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+  ASSERT_TRUE(instance_norm->running_mean.defined());
+  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+  ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->running_var.defined());
+  ASSERT_EQ(instance_norm->running_var.dim(), 1);
+  ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+  ASSERT_TRUE(instance_norm->options.affine());
+
+  ASSERT_TRUE(instance_norm->weight.defined());
+  ASSERT_EQ(instance_norm->weight.dim(), 1);
+  ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->bias.defined());
+  ASSERT_EQ(instance_norm->bias.dim(), 1);
+  ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm2dStateless) {
+  InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
+
+  ASSERT_FALSE(instance_norm->running_mean.defined());
+  ASSERT_FALSE(instance_norm->running_var.defined());
+  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+  ASSERT_FALSE(instance_norm->weight.defined());
+  ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm2d) {
+  InstanceNorm2d instance_norm(5);
+  instance_norm->eval();
+
+  auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
+  auto output = instance_norm->forward(input);
+  auto expected = torch::tensor({{{{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}}},
+                                 {{{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}},
+                                  {{-1.3416, -0.4472},
+                                   { 0.4472,  1.3416}}}});
+  ASSERT_TRUE(output.allclose(expected, 1e-3));
+  auto s = output.sum();
+  s.backward();
+
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
+}
+
+TEST_F(ModulesTest, InstanceNorm3dStateful) {
+  InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
+
+  ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+  ASSERT_TRUE(instance_norm->running_mean.defined());
+  ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+  ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->running_var.defined());
+  ASSERT_EQ(instance_norm->running_var.dim(), 1);
+  ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+  ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+  ASSERT_TRUE(instance_norm->options.affine());
+
+  ASSERT_TRUE(instance_norm->weight.defined());
+  ASSERT_EQ(instance_norm->weight.dim(), 1);
+  ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+  ASSERT_TRUE(instance_norm->bias.defined());
+  ASSERT_EQ(instance_norm->bias.dim(), 1);
+  ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm3dStateless) {
+  InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
+
+  ASSERT_FALSE(instance_norm->running_mean.defined());
+  ASSERT_FALSE(instance_norm->running_var.defined());
+  ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+  ASSERT_FALSE(instance_norm->weight.defined());
+  ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm3d) {
+  InstanceNorm3d instance_norm(5);
+  instance_norm->eval();
+
+  auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
+  auto output = instance_norm->forward(input);
+  auto expected = torch::tensor({{{{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}}},
+                                 {{{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}},
+                                  {{{-1.5275, -1.0911},
+                                    {-0.6547, -0.2182}},
+                                   {{ 0.2182,  0.6547},
+                                    { 1.0911,  1.5275}}}}});
+  ASSERT_TRUE(output.allclose(expected, 1e-3));
+  auto s = output.sum();
+  s.backward();
+
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
 }
 
 TEST_F(ModulesTest, Linear_CUDA) {
@@ -3178,6 +3468,30 @@
       "torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
 }
 
+TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
+  ASSERT_EQ(
+      c10::str(InstanceNorm1d(
+          InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
+          .track_running_stats(true))),
+      "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
+  ASSERT_EQ(
+      c10::str(InstanceNorm2d(
+          InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false)
+          .track_running_stats(true))),
+      "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
+  ASSERT_EQ(
+      c10::str(InstanceNorm3d(
+          InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false)
+          .track_running_stats(true))),
+      "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
 TEST_F(ModulesTest, PrettyPrintLayerNorm) {
   ASSERT_EQ(
     c10::str(LayerNorm(LayerNormOptions({2, 2}))),
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 7a80b4c..55d14be 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -75,9 +75,9 @@
 torch.nn.BatchNorm3d|Yes|No
 torch.nn.GroupNorm|Yes|No
 torch.nn.SyncBatchNorm|No|No
-torch.nn.InstanceNorm1d|No|No
-torch.nn.InstanceNorm2d|No|No
-torch.nn.InstanceNorm3d|No|No
+torch.nn.InstanceNorm1d|Yes|No
+torch.nn.InstanceNorm2d|Yes|No
+torch.nn.InstanceNorm3d|Yes|No
 torch.nn.LayerNorm|Yes|No
 torch.nn.LocalResponseNorm|Yes|No
 torch.nn.CrossMapLRN2d|Yes|No
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 80d7dbc..1e132db 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -221,6 +221,7 @@
         "torch/csrc/api/src/nn/modules/activation.cpp",
         "torch/csrc/api/src/nn/modules/batchnorm.cpp",
         "torch/csrc/api/src/nn/modules/normalization.cpp",
+        "torch/csrc/api/src/nn/modules/instancenorm.cpp",
         "torch/csrc/api/src/nn/modules/conv.cpp",
         "torch/csrc/api/src/nn/modules/dropout.cpp",
         "torch/csrc/api/src/nn/modules/distance.cpp",
@@ -239,6 +240,7 @@
         "torch/csrc/api/src/nn/options/batchnorm.cpp",
         "torch/csrc/api/src/nn/options/conv.cpp",
         "torch/csrc/api/src/nn/options/dropout.cpp",
+        "torch/csrc/api/src/nn/options/instancenorm.cpp",
         "torch/csrc/api/src/nn/options/linear.cpp",
         "torch/csrc/api/src/nn/options/normalization.cpp",
         "torch/csrc/api/src/nn/options/embedding.cpp",
diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h
index 8fe4af7..0a9db5c 100644
--- a/torch/csrc/api/include/torch/nn/functional.h
+++ b/torch/csrc/api/include/torch/nn/functional.h
@@ -14,3 +14,4 @@
 #include <torch/nn/functional/pooling.h>
 #include <torch/nn/functional/upsampling.h>
 #include <torch/nn/functional/vision.h>
+#include <torch/nn/functional/instancenorm.h>
diff --git a/torch/csrc/api/include/torch/nn/functional/instancenorm.h b/torch/csrc/api/include/torch/nn/functional/instancenorm.h
new file mode 100644
index 0000000..b5534d7
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/functional/instancenorm.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+namespace detail {
+inline Tensor instance_norm(const Tensor& input, const Tensor& running_mean,
+    const Tensor& running_var, const Tensor& weight, const Tensor& bias,
+    bool use_input_stats, double momentum, double eps) {
+
+  return torch::instance_norm(
+    input, weight, bias, running_mean, running_var,
+    use_input_stats, momentum, eps, at::globalContext().userEnabledCuDNN()
+  );
+}
+} // namespace detail
+
+inline Tensor instance_norm(const Tensor& input, const InstanceNormFuncOptions& options = {}) {
+  return detail::instance_norm(
+    input, options.running_mean(),
+    options.running_var(), options.weight(), options.bias(),
+    options.use_input_stats(), options.momentum(), options.eps());
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h
index 2239e02..57a9213 100644
--- a/torch/csrc/api/include/torch/nn/modules.h
+++ b/torch/csrc/api/include/torch/nn/modules.h
@@ -9,6 +9,7 @@
 
 // Layers
 #include <torch/nn/modules/batchnorm.h>
+#include <torch/nn/modules/instancenorm.h>
 #include <torch/nn/modules/conv.h>
 #include <torch/nn/modules/dropout.h>
 #include <torch/nn/modules/distance.h>
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 46caf6a..9272c42 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -1,12 +1,16 @@
 #pragma once
 
 #include <torch/nn/cloneable.h>
+#include <torch/nn/functional/batchnorm.h>
 #include <torch/nn/options/batchnorm.h>
+#include <torch/nn/init.h>
 #include <torch/nn/pimpl.h>
 #include <torch/types.h>
 
 #include <cstdint>
 
+namespace F = torch::nn::functional;
+
 namespace torch {
 namespace nn {
 
@@ -77,28 +81,55 @@
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-/// Base class for all (dimension-specialized) batchnorm modules.
-template <size_t D, typename Derived>
-class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
+/// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
+template <size_t D, typename Derived, typename DerivedOptions>
+class NormImplBase : public torch::nn::Cloneable<Derived> {
  protected:
   virtual void _check_input_dim(const Tensor& input) = 0;
 
  public:
-  explicit BatchNormImplBase(const BatchNormOptions& options_);
+  NormImplBase(const DerivedOptions& options_) : options(options_) {
+    reset();
+  }
 
-  Tensor forward(const Tensor& input);
+  void reset() override {
+    if (options.affine()) {
+      weight = this->register_parameter("weight", torch::empty({options.num_features()}));
+      bias = this->register_parameter("bias", torch::empty({options.num_features()}));
+    } else {
+      weight = this->register_parameter("weight", Tensor());
+      bias = this->register_parameter("bias", Tensor());
+    }
+    if (options.track_running_stats()) {
+      running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
+      running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
+      num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
+    } else {
+      running_mean = this->register_buffer("running_mean", Tensor());
+      running_var = this->register_buffer("running_var", Tensor());
+      num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
+    }
+    reset_parameters();
+  }
 
-  void reset() override;
+  void reset_running_stats() {
+    if (options.track_running_stats()) {
+      running_mean.zero_();
+      running_var.fill_(1);
+      num_batches_tracked.zero_();
+    }
+  }
 
-  void reset_running_stats();
-
-  void reset_parameters();
-
-  /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
+  void reset_parameters() {
+    reset_running_stats();
+    if (options.affine()) {
+      torch::nn::init::ones_(weight);
+      torch::nn::init::zeros_(bias);
+    }
+  }
 
   /// The options with which this module was constructed.
-  BatchNormOptions options;
+  DerivedOptions options;
 
   /// The learned weight.
   /// Only defined if the `affine` option was `true` upon construction.
@@ -121,6 +152,47 @@
   Tensor num_batches_tracked;
 };
 
+/// Base class for all (dimension-specialized) batchnorm modules.
+template <size_t D, typename Derived>
+class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
+ public:
+  using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;
+
+  Tensor forward(const Tensor& input) {
+    this->_check_input_dim(input);
+    double exponential_average_factor;
+    if (this->options.momentum() == c10::nullopt) {
+      exponential_average_factor = 0.0;
+    } else {
+      exponential_average_factor = this->options.momentum().value();
+    }
+
+    if (this->is_training() && this->options.track_running_stats()) {
+      if (this->num_batches_tracked.defined()) {
+        this->num_batches_tracked += 1;
+        if (this->options.momentum() == c10::nullopt) {  // use cumulative moving average
+          exponential_average_factor = 1.0 / this->num_batches_tracked.template item<double>();
+        } else {  // use exponential moving average
+          exponential_average_factor = this->options.momentum().value();
+        }
+      }
+    }
+
+    return F::detail::batch_norm(
+        input,
+        this->running_mean,
+        this->running_var,
+        this->weight,
+        this->bias,
+        this->is_training() || !this->options.track_running_stats(),
+        /*momentum=*/exponential_average_factor,
+        this->options.eps());
+  }
+
+  /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+};
+
 /// Applies the BatchNorm1d function.
 /// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
 /// about the exact behavior of this module.
diff --git a/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/torch/csrc/api/include/torch/nn/modules/instancenorm.h
new file mode 100644
index 0000000..21164d9
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/modules/instancenorm.h
@@ -0,0 +1,66 @@
+#pragma once
+
+#include <torch/nn/modules/batchnorm.h>
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+
+/// Base class for all (dimension-specialized) instance norm modules
+template <size_t D, typename Derived>
+class InstanceNormImpl : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
+ public:
+  using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
+
+  Tensor forward(const Tensor& input) {
+    this->_check_input_dim(input);
+    return F::detail::instance_norm(
+      input, this->running_mean, this->running_var, this->weight, this->bias,
+      this->is_training() || !this->options.track_running_stats(), this->options.momentum(), this->options.eps());
+  }
+
+  /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
+  void pretty_print(std::ostream& stream) const override;
+};
+
+/// Applies the InstanceNorm1d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm1d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {
+ protected:
+  virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+  using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm1d);
+
+/// Applies the InstanceNorm2d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm2d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm2dImpl : public InstanceNormImpl<2, InstanceNorm2dImpl> {
+ protected:
+  virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+  using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm2d);
+
+/// Applies the InstanceNorm3d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm3d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm3dImpl : public InstanceNormImpl<3, InstanceNorm3dImpl> {
+ protected:
+  virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+  using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm3d);
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h
index 5495c07..1770191 100644
--- a/torch/csrc/api/include/torch/nn/options/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h
@@ -43,7 +43,7 @@
 
 namespace functional {
 
-/// Options for the `BatchNorm` module.
+/// Options for the `BatchNorm` functional.
 struct TORCH_API BatchNormFuncOptions {
   TORCH_ARG(Tensor, weight) = Tensor();
 
diff --git a/torch/csrc/api/include/torch/nn/options/instancenorm.h b/torch/csrc/api/include/torch/nn/options/instancenorm.h
new file mode 100644
index 0000000..8c536c2
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/options/instancenorm.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <torch/arg.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/nn/options/batchnorm.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+
+/// Options for the `InstanceNorm` module.
+struct TORCH_API InstanceNormOptions {
+  /* implicit */ InstanceNormOptions(int64_t num_features);
+
+  /// The number of features of the input tensor.
+  TORCH_ARG(int64_t, num_features);
+
+  /// The epsilon value added for numerical stability.
+  TORCH_ARG(double, eps) = 1e-5;
+
+  /// A momentum multiplier for the mean and variance.
+  TORCH_ARG(double, momentum) = 0.1;
+
+  /// Whether to learn a scale and bias that are applied in an affine
+  /// transformation on the input.
+  TORCH_ARG(bool, affine) = false;
+
+  /// Whether to store and update batch statistics (mean and variance) in the
+  /// module.
+  TORCH_ARG(bool, track_running_stats) = false;
+};
+
+using InstanceNorm1dOptions = InstanceNormOptions;
+using InstanceNorm2dOptions = InstanceNormOptions;
+using InstanceNorm3dOptions = InstanceNormOptions;
+
+namespace functional {
+
+/// Options for the `InstanceNorm` functional.
+struct TORCH_API InstanceNormFuncOptions {
+  TORCH_ARG(Tensor, running_mean) = Tensor();
+
+  TORCH_ARG(Tensor, running_var) = Tensor();
+
+  TORCH_ARG(Tensor, weight) = Tensor();
+
+  TORCH_ARG(Tensor, bias) = Tensor();
+
+  TORCH_ARG(bool, use_input_stats) = true;
+
+  TORCH_ARG(double, momentum) = 0.1;
+
+  TORCH_ARG(double, eps) = 1e-5;
+};
+
+} // namespace functional
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 2311850..ec05a9c 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -3,7 +3,6 @@
 
 #include <torch/cuda.h>
 #include <torch/types.h>
-#include <torch/nn/init.h>
 
 #include <c10/util/Exception.h>
 
@@ -17,7 +16,7 @@
 namespace torch {
 namespace nn {
 
-BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
+BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
   TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
              "Use BatchNorm{1,2,3}d instead.");
   reset();
@@ -78,93 +77,17 @@
       torch::cuda::cudnn_is_available());
 }
 
-template <size_t D, typename Derived>
-BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
-    : options(options_) {
-  reset();
-}
+// ===========================================================================
 
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset() {
-  if (options.affine()) {
-    weight = this->register_parameter("weight", torch::empty({options.num_features()}));
-    bias = this->register_parameter("bias", torch::empty({options.num_features()}));
-  } else {
-    weight = this->register_parameter("weight", Tensor());
-    bias = this->register_parameter("bias", Tensor());
-  }
-  if (options.track_running_stats()) {
-    running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
-    running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
-    num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
-  } else {
-    running_mean = this->register_buffer("running_mean", Tensor());
-    running_var = this->register_buffer("running_var", Tensor());
-    num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
-  }
-  reset_parameters();
-}
-
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset_running_stats() {
-  if (options.track_running_stats()) {
-    running_mean.zero_();
-    running_var.fill_(1);
-    num_batches_tracked.zero_();
-  }
-}
-
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset_parameters() {
-  reset_running_stats();
-  if (options.affine()) {
-    torch::nn::init::ones_(weight);
-    torch::nn::init::zeros_(bias);
-  }
-}
-
-template <size_t D, typename Derived>
+template <size_t D, typename Derived>
 void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
   stream << std::boolalpha
          << "torch::nn::BatchNorm" << D << "d("
-         << options.num_features() << ", "
-         << "eps=" << options.eps() << ", "
-         << "momentum=" << options.momentum().value() << ", "
-         << "affine=" << options.affine() << ", "
-         << "track_running_stats=" << options.track_running_stats() << ")";
-}
-
-template <size_t D, typename Derived>
-Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
-  _check_input_dim(input);
-
-  double exponential_average_factor;
-  if (options.momentum() == c10::nullopt) {
-    exponential_average_factor = 0.0;
-  } else {
-    exponential_average_factor = options.momentum().value();
-  }
-
-  if (this->is_training() && options.track_running_stats()) {
-    if (num_batches_tracked.defined()) {
-      num_batches_tracked += 1;
-      if (options.momentum() == c10::nullopt) {  // use cumulative moving average
-        exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
-      } else {  // use exponential moving average
-        exponential_average_factor = options.momentum().value();
-      }
-    }
-  }
-
-  return F::detail::batch_norm(
-      input,
-      running_mean,
-      running_var,
-      weight,
-      bias,
-      this->is_training() || !options.track_running_stats(),
-      /*momentum=*/exponential_average_factor,
-      options.eps());
+         << this->options.num_features() << ", "
+         << "eps=" << this->options.eps() << ", "
+         << "momentum=" << this->options.momentum().value() << ", "
+         << "affine=" << this->options.affine() << ", "
+         << "track_running_stats=" << this->options.track_running_stats() << ")";
 }
 
 void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
diff --git a/torch/csrc/api/src/nn/modules/instancenorm.cpp b/torch/csrc/api/src/nn/modules/instancenorm.cpp
new file mode 100644
index 0000000..c400c94
--- /dev/null
+++ b/torch/csrc/api/src/nn/modules/instancenorm.cpp
@@ -0,0 +1,57 @@
+#include <torch/nn/functional/instancenorm.h>
+#include <torch/nn/modules/instancenorm.h>
+
+namespace F = torch::nn::functional;
+
+namespace torch {
+namespace nn {
+
+template <size_t D, typename Derived>
+void InstanceNormImpl<D, Derived>::pretty_print(std::ostream& stream) const {
+  stream << std::boolalpha
+         << "torch::nn::InstanceNorm" << D << "d("
+         << this->options.num_features() << ", "
+         << "eps=" << this->options.eps() << ", "
+         << "momentum=" << this->options.momentum() << ", "
+         << "affine=" << this->options.affine() << ", "
+         << "track_running_stats=" << this->options.track_running_stats() << ")";
+}
+
+void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
+  if (input.dim() == 2) {
+    TORCH_CHECK(
+      false,
+      "InstanceNorm1d returns 0-filled tensor to 2D tensor. ",
+      "This is because InstanceNorm1d reshapes inputs to ",
+      "(1, N * C, ...) from (N, C, ...) and this makes ",
+      "variances 0.");
+  }
+  if (input.dim() != 3) {
+    TORCH_CHECK(
+      false,
+      "expected 3D input (got ", input.dim(), "D input)");
+  }
+}
+
+void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
+  if (input.dim() != 4) {
+    TORCH_CHECK(
+      false,
+      "expected 4D input (got ", input.dim(), "D input)");
+  }
+}
+
+void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
+  if (input.dim() != 5) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+    TORCH_CHECK(
+      false,
+      "expected 5D input (got ", input.dim(), "D input)");
+  }
+}
+
+template class InstanceNormImpl<1, InstanceNorm1dImpl>;
+template class InstanceNormImpl<2, InstanceNorm2dImpl>;
+template class InstanceNormImpl<3, InstanceNorm3dImpl>;
+
+} // namespace nn
+} // namespace torch
diff --git a/torch/csrc/api/src/nn/options/instancenorm.cpp b/torch/csrc/api/src/nn/options/instancenorm.cpp
new file mode 100644
index 0000000..930fc27
--- /dev/null
+++ b/torch/csrc/api/src/nn/options/instancenorm.cpp
@@ -0,0 +1,9 @@
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+
+InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {}
+
+} // namespace nn
+} // namespace torch