InstanceNorm{1,2,3}d (#28790)
Summary:
Hi yf225,
I have a few questions about the implementation:
1) What tests do I have to write?
2) What does _load_state_from_dict do?
3) Do I need to override the reset() function? I cannot see its utility.
4) InstanceNormOptions could be replaced with BatchNormOptions, but I find that
`track_running_stats` is not defined there; instead, `stateful` is defined.
InstanceNorm{1,2,3}d https://github.com/pytorch/pytorch/issues/25883
Pull Request resolved: https://github.com/pytorch/pytorch/pull/28790
Differential Revision: D18588666
Pulled By: yf225
fbshipit-source-id: bb9b81f01f62c3fc8765fa0ba0716768087ee155
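
For context, a minimal usage sketch of the module API this patch adds (assuming a libtorch build that contains this change; shapes are illustrative):

    #include <torch/torch.h>

    int main() {
      // InstanceNorm2d over an (N, C, H, W) input. As in Python, affine and
      // track_running_stats default to false for instance norm.
      torch::nn::InstanceNorm2d norm(torch::nn::InstanceNorm2dOptions(5));
      auto input = torch::randn({2, 5, 4, 4});
      auto output = norm->forward(input);  // same shape as the input
    }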
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index ba36dd0..c05af39 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -566,6 +566,7 @@
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/activation.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/normalization.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/nn/modules/instancenorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/conv.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/modules/distance.cpp
@@ -583,11 +584,11 @@
${TORCH_SRC_DIR}/csrc/api/src/nn/options/activation.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/batchnorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/embedding.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/nn/options/instancenorm.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/conv.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/dropout.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/linear.cpp
- ${TORCH_SRC_DIR}/csrc/api/src/nn/options/normalization.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/padding.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/pooling.cpp
${TORCH_SRC_DIR}/csrc/api/src/nn/options/rnn.cpp
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index d4aca98..981c155 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -1579,6 +1579,236 @@
ASSERT_TRUE(output.allclose(expected));
}
+TEST_F(FunctionalTest, InstanceNorm1d) {
+ int num_features = 5;
+ double eps = 1e-05;
+ double momentum = 0.1;
+
+ auto input = torch::arange(40.).view({2, 5, 4});
+ auto mean = torch::arange(5.);
+ auto variance = torch::arange(5.);
+ auto weight = torch::arange((double)num_features);
+ auto bias = torch::arange((double)num_features);
+ auto output = F::instance_norm(
+ input,
+ F::InstanceNormFuncOptions()
+ .running_mean(mean)
+ .running_var(variance)
+ .weight(weight)
+ .bias(bias)
+ .momentum(momentum)
+ .eps(eps));
+ auto expected = torch::tensor({{{ 0.0000, 0.0000, 0.0000, 0.0000},
+ {-0.3416, 0.5528, 1.4472, 2.3416},
+ {-0.6833, 1.1056, 2.8944, 4.6833},
+ {-1.0249, 1.6584, 4.3416, 7.0249},
+ {-1.3665, 2.2112, 5.7888, 9.3665}},
+ {{ 0.0000, 0.0000, 0.0000, 0.0000},
+ {-0.3416, 0.5528, 1.4472, 2.3416},
+ {-0.6833, 1.1056, 2.8944, 4.6833},
+ {-1.0249, 1.6584, 4.3416, 7.0249},
+ {-1.3665, 2.2112, 5.7888, 9.3665}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm1dDefaultOptions) {
+ auto input = torch::arange(40.).view({2, 5, 4});
+ auto output = F::instance_norm(input);
+ auto expected = torch::tensor({{{-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416}},
+ {{-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416},
+ {-1.3416, -0.4472, 0.4472, 1.3416}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm2d) {
+ int num_features = 5;
+ double eps = 1e-05;
+ double momentum = 0.1;
+
+ auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
+ auto mean = torch::arange((double)num_features);
+ auto variance = torch::arange((double)num_features);
+ auto weight = torch::arange((double)num_features);
+ auto bias = torch::arange((double)num_features);
+ auto output = F::instance_norm(
+ input,
+ F::InstanceNormFuncOptions()
+ .running_mean(mean)
+ .running_var(variance)
+ .weight(weight)
+ .bias(bias)
+ .momentum(momentum)
+ .eps(eps));
+ auto expected = torch::tensor({{{{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}},
+ {{-0.3416, 0.5528},
+ { 1.4472, 2.3416}},
+ {{-0.6833, 1.1056},
+ { 2.8944, 4.6833}},
+ {{-1.0249, 1.6584},
+ { 4.3416, 7.0249}},
+ {{-1.3665, 2.2112},
+ { 5.7888, 9.3665}}},
+ {{{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}},
+ {{-0.3416, 0.5528},
+ { 1.4472, 2.3416}},
+ {{-0.6833, 1.1056},
+ { 2.8944, 4.6833}},
+ {{-1.0249, 1.6584},
+ { 4.3416, 7.0249}},
+ {{-1.3665, 2.2112},
+ { 5.7888, 9.3665}}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm2dDefaultOptions) {
+ int num_features = 5;
+ double eps = 1e-05;
+
+ auto input = torch::arange(2. * num_features * 2 * 2).view({2, num_features, 2, 2});
+ auto output = F::instance_norm(input);
+ auto expected = torch::tensor({{{{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}}},
+ {{{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm3d) {
+ int num_features = 5;
+ double eps = 1e-05;
+ double momentum = 0.1;
+
+ auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
+ auto mean = torch::arange((double)num_features);
+ auto variance = torch::arange((double)num_features);
+ auto weight = torch::arange((double)num_features);
+ auto bias = torch::arange((double)num_features);
+ auto output = F::instance_norm(
+ input,
+ F::InstanceNormFuncOptions()
+ .running_mean(mean)
+ .running_var(variance)
+ .weight(weight)
+ .bias(bias)
+ .momentum(momentum)
+ .eps(eps));
+ auto expected = torch::tensor({{{{{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}},
+ {{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}}},
+ {{{-0.5275, -0.0911},
+ { 0.3453, 0.7818}},
+ {{ 1.2182, 1.6547},
+ { 2.0911, 2.5275}}},
+ {{{-1.0550, -0.1822},
+ { 0.6907, 1.5636}},
+ {{ 2.4364, 3.3093},
+ { 4.1822, 5.0550}}},
+ {{{-1.5826, -0.2733},
+ { 1.0360, 2.3453}},
+ {{ 3.6547, 4.9640},
+ { 6.2733, 7.5826}}},
+ {{{-2.1101, -0.3644},
+ { 1.3814, 3.1271}},
+ {{ 4.8729, 6.6186},
+ { 8.3644, 10.1101}}}},
+ {{{{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}},
+ {{ 0.0000, 0.0000},
+ { 0.0000, 0.0000}}},
+ {{{-0.5275, -0.0911},
+ { 0.3453, 0.7818}},
+ {{ 1.2182, 1.6547},
+ { 2.0911, 2.5275}}},
+ {{{-1.0550, -0.1822},
+ { 0.6907, 1.5636}},
+ {{ 2.4364, 3.3093},
+ { 4.1822, 5.0550}}},
+ {{{-1.5826, -0.2733},
+ { 1.0360, 2.3453}},
+ {{ 3.6547, 4.9640},
+ { 6.2733, 7.5826}}},
+ {{{-2.1101, -0.3644},
+ { 1.3814, 3.1271}},
+ {{ 4.8729, 6.6186},
+ { 8.3644, 10.1101}}}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
+TEST_F(FunctionalTest, InstanceNorm3dDefaultOptions) {
+ int num_features = 5;
+ double eps = 1e-05;
+
+ auto input = torch::arange(2. * num_features * 2 * 2 * 2).view({2, num_features, 2, 2, 2});
+ auto output = F::instance_norm(input);
+ auto expected = torch::tensor({{{{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}}},
+ {{{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}}}});
+ ASSERT_TRUE(output.allclose(expected, 2e-04));
+}
+
TEST_F(FunctionalTest, Interpolate) {
{
// 1D interpolation
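
A note on where the hard-coded expectations above come from: with default options, each (batch, channel) slice is normalized by its own mean and biased variance. For the slice {0, 1, 2, 3}, the mean is 1.5 and the biased variance is 1.25, so (x - 1.5) / sqrt(1.25 + eps) ≈ {-1.3416, -0.4472, 0.4472, 1.3416} — exactly the row asserted in InstanceNorm1dDefaultOptions. A quick check in libtorch:

    auto x = torch::tensor({0., 1., 2., 3.});
    // Instance norm uses the biased variance estimate, hence unbiased = false.
    auto y = (x - x.mean()) / torch::sqrt(x.var(/*unbiased=*/false) + 1e-5);
    // y ≈ {-1.3416, -0.4472, 0.4472, 1.3416}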
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 96a1b0e..33fdb90 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -1427,7 +1427,7 @@
}
TEST_F(ModulesTest, BatchNorm1dStateful) {
- BatchNorm1d bn(BatchNorm1dOptions(5));
+ BatchNorm1d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1464,20 +1464,30 @@
}
TEST_F(ModulesTest, BatchNorm1d) {
- BatchNorm1d bn(BatchNorm1dOptions(5));
+ BatchNorm1d bn(5);
bn->eval();
- auto input = torch::randn({2, 5}, torch::requires_grad());
+ auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
auto output = bn->forward(input);
+ auto expected = torch::tensor({{{ 0.0000, 1.0000},
+ { 2.0000, 3.0000},
+ { 4.0000, 5.0000},
+ { 6.0000, 7.0000},
+ { 8.0000, 9.0000}},
+ {{10.0000, 10.9999},
+ {11.9999, 12.9999},
+ {13.9999, 14.9999},
+ {15.9999, 16.9999},
+ {17.9999, 18.9999}}});
+ ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
- ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5})));
}
TEST_F(ModulesTest, BatchNorm2dStateful) {
- BatchNorm2d bn(BatchNorm2dOptions(5));
+ BatchNorm2d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1514,20 +1524,40 @@
}
TEST_F(ModulesTest, BatchNorm2d) {
- BatchNorm2d bn(BatchNorm2dOptions(5));
+ BatchNorm2d bn(5);
bn->eval();
- auto input = torch::randn({2, 5, 4, 4}, torch::requires_grad());
+ auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
auto output = bn->forward(input);
+ auto expected = torch::tensor({{{{ 0.0000, 1.0000},
+ { 2.0000, 3.0000}},
+ {{ 4.0000, 5.0000},
+ { 6.0000, 7.0000}},
+ {{ 8.0000, 9.0000},
+ {10.0000, 10.9999}},
+ {{11.9999, 12.9999},
+ {13.9999, 14.9999}},
+ {{15.9999, 16.9999},
+ {17.9999, 18.9999}}},
+ {{{19.9999, 20.9999},
+ {21.9999, 22.9999}},
+ {{23.9999, 24.9999},
+ {25.9999, 26.9999}},
+ {{27.9999, 28.9999},
+ {29.9998, 30.9998}},
+ {{31.9998, 32.9998},
+ {33.9998, 34.9998}},
+ {{35.9998, 36.9998},
+ {37.9998, 38.9998}}}});
+ ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
- ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4})));
}
TEST_F(ModulesTest, BatchNorm3dStateful) {
- BatchNorm3d bn(BatchNorm3dOptions(5));
+ BatchNorm3d bn(5);
ASSERT_TRUE(bn->options.track_running_stats());
@@ -1564,16 +1594,276 @@
}
TEST_F(ModulesTest, BatchNorm3d) {
- BatchNorm3d bn(BatchNorm3dOptions(5));
+ BatchNorm3d bn(5);
bn->eval();
- auto input = torch::randn({2, 5, 4, 4, 4}, torch::requires_grad());
+ auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
auto output = bn->forward(input);
+ auto expected = torch::tensor({{{{{ 0.0000, 1.0000},
+ { 2.0000, 3.0000}},
+ {{ 4.0000, 5.0000},
+ { 6.0000, 7.0000}}},
+ {{{ 8.0000, 9.0000},
+ {10.0000, 10.9999}},
+ {{11.9999, 12.9999},
+ {13.9999, 14.9999}}},
+ {{{15.9999, 16.9999},
+ {17.9999, 18.9999}},
+ {{19.9999, 20.9999},
+ {21.9999, 22.9999}}},
+ {{{23.9999, 24.9999},
+ {25.9999, 26.9999}},
+ {{27.9999, 28.9999},
+ {29.9998, 30.9998}}},
+ {{{31.9998, 32.9998},
+ {33.9998, 34.9998}},
+ {{35.9998, 36.9998},
+ {37.9998, 38.9998}}}},
+ {{{{39.9998, 40.9998},
+ {41.9998, 42.9998}},
+ {{43.9998, 44.9998},
+ {45.9998, 46.9998}}},
+ {{{47.9998, 48.9998},
+ {49.9997, 50.9997}},
+ {{51.9997, 52.9997},
+ {53.9997, 54.9997}}},
+ {{{55.9997, 56.9997},
+ {57.9997, 58.9997}},
+ {{59.9997, 60.9997},
+ {61.9997, 62.9997}}},
+ {{{63.9997, 64.9997},
+ {65.9997, 66.9997}},
+ {{67.9997, 68.9997},
+ {69.9996, 70.9996}}},
+ {{{71.9996, 72.9996},
+ {73.9996, 74.9996}},
+ {{75.9996, 76.9996},
+ {77.9996, 78.9996}}}}});
+ ASSERT_TRUE(output.allclose(expected));
auto s = output.sum();
s.backward();
ASSERT_EQ(input.sizes(), input.grad().sizes());
- ASSERT_TRUE(input.grad().allclose(torch::ones({2, 5, 4, 4, 4})));
+}
+
+TEST_F(ModulesTest, InstanceNorm1dStateful) {
+ InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
+
+ ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+ ASSERT_TRUE(instance_norm->running_mean.defined());
+ ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+ ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->running_var.defined());
+ ASSERT_EQ(instance_norm->running_var.dim(), 1);
+ ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+ ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+ ASSERT_TRUE(instance_norm->options.affine());
+
+ ASSERT_TRUE(instance_norm->weight.defined());
+ ASSERT_EQ(instance_norm->weight.dim(), 1);
+ ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->bias.defined());
+ ASSERT_EQ(instance_norm->bias.dim(), 1);
+ ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm1dStateless) {
+ InstanceNorm1d instance_norm(InstanceNorm1dOptions(5).track_running_stats(false).affine(false));
+
+ ASSERT_FALSE(instance_norm->running_mean.defined());
+ ASSERT_FALSE(instance_norm->running_var.defined());
+ ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+ ASSERT_FALSE(instance_norm->weight.defined());
+ ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm1d) {
+ InstanceNorm1d instance_norm(5);
+ instance_norm->eval();
+
+ auto input = torch::arange(2. * 5 * 2).view({2, 5, 2}).requires_grad_();
+ auto output = instance_norm->forward(input);
+ auto expected = torch::tensor({{{-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000}},
+ {{-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000},
+ {-1.0000, 1.0000}}});
+ ASSERT_TRUE(output.allclose(expected, 1e-3));
+ auto s = output.sum();
+ s.backward();
+
+ ASSERT_EQ(input.sizes(), input.grad().sizes());
+}
+
+TEST_F(ModulesTest, InstanceNorm2dStateful) {
+ InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(true).affine(true));
+
+ ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+ ASSERT_TRUE(instance_norm->running_mean.defined());
+ ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+ ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->running_var.defined());
+ ASSERT_EQ(instance_norm->running_var.dim(), 1);
+ ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+ ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+ ASSERT_TRUE(instance_norm->options.affine());
+
+ ASSERT_TRUE(instance_norm->weight.defined());
+ ASSERT_EQ(instance_norm->weight.dim(), 1);
+ ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->bias.defined());
+ ASSERT_EQ(instance_norm->bias.dim(), 1);
+ ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm2dStateless) {
+ InstanceNorm2d instance_norm(InstanceNorm2dOptions(5).track_running_stats(false).affine(false));
+
+ ASSERT_FALSE(instance_norm->running_mean.defined());
+ ASSERT_FALSE(instance_norm->running_var.defined());
+ ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+ ASSERT_FALSE(instance_norm->weight.defined());
+ ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm2d) {
+ InstanceNorm2d instance_norm(5);
+ instance_norm->eval();
+
+ auto input = torch::arange(2. * 5 * 2 * 2).view({2, 5, 2, 2}).requires_grad_();
+ auto output = instance_norm->forward(input);
+ auto expected = torch::tensor({{{{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}}},
+ {{{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}},
+ {{-1.3416, -0.4472},
+ { 0.4472, 1.3416}}}});
+ ASSERT_TRUE(output.allclose(expected, 1e-3));
+ auto s = output.sum();
+ s.backward();
+
+ ASSERT_EQ(input.sizes(), input.grad().sizes());
+}
+
+TEST_F(ModulesTest, InstanceNorm3dStateful) {
+ InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(true).affine(true));
+
+ ASSERT_TRUE(instance_norm->options.track_running_stats());
+
+ ASSERT_TRUE(instance_norm->running_mean.defined());
+ ASSERT_EQ(instance_norm->running_mean.dim(), 1);
+ ASSERT_EQ(instance_norm->running_mean.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->running_var.defined());
+ ASSERT_EQ(instance_norm->running_var.dim(), 1);
+ ASSERT_EQ(instance_norm->running_var.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->num_batches_tracked.defined());
+ ASSERT_EQ(instance_norm->num_batches_tracked.dim(), 0);
+
+ ASSERT_TRUE(instance_norm->options.affine());
+
+ ASSERT_TRUE(instance_norm->weight.defined());
+ ASSERT_EQ(instance_norm->weight.dim(), 1);
+ ASSERT_EQ(instance_norm->weight.size(0), 5);
+
+ ASSERT_TRUE(instance_norm->bias.defined());
+ ASSERT_EQ(instance_norm->bias.dim(), 1);
+ ASSERT_EQ(instance_norm->bias.size(0), 5);
+}
+
+TEST_F(ModulesTest, InstanceNorm3dStateless) {
+ InstanceNorm3d instance_norm(InstanceNorm3dOptions(5).track_running_stats(false).affine(false));
+
+ ASSERT_FALSE(instance_norm->running_mean.defined());
+ ASSERT_FALSE(instance_norm->running_var.defined());
+ ASSERT_FALSE(instance_norm->num_batches_tracked.defined());
+ ASSERT_FALSE(instance_norm->weight.defined());
+ ASSERT_FALSE(instance_norm->bias.defined());
+}
+
+TEST_F(ModulesTest, InstanceNorm3d) {
+ InstanceNorm3d instance_norm(5);
+ instance_norm->eval();
+
+ auto input = torch::arange(2. * 5 * 2 * 2 * 2).view({2, 5, 2, 2, 2}).requires_grad_();
+ auto output = instance_norm->forward(input);
+ auto expected = torch::tensor({{{{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}}},
+ {{{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}},
+ {{{-1.5275, -1.0911},
+ {-0.6547, -0.2182}},
+ {{ 0.2182, 0.6547},
+ { 1.0911, 1.5275}}}}});
+ ASSERT_TRUE(output.allclose(expected, 1e-3));
+ auto s = output.sum();
+ s.backward();
+
+ ASSERT_EQ(input.sizes(), input.grad().sizes());
}
TEST_F(ModulesTest, Linear_CUDA) {
@@ -3178,6 +3468,30 @@
"torch::nn::BatchNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
}
+TEST_F(ModulesTest, PrettyPrintInstanceNorm1d) {
+ ASSERT_EQ(
+ c10::str(InstanceNorm1d(
+ InstanceNorm1dOptions(4).eps(0.5).momentum(0.1).affine(false)
+ .track_running_stats(true))),
+ "torch::nn::InstanceNorm1d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintInstanceNorm2d) {
+ ASSERT_EQ(
+ c10::str(InstanceNorm2d(
+ InstanceNorm2dOptions(4).eps(0.5).momentum(0.1).affine(false)
+ .track_running_stats(true))),
+ "torch::nn::InstanceNorm2d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
+TEST_F(ModulesTest, PrettyPrintInstanceNorm3d) {
+ ASSERT_EQ(
+ c10::str(InstanceNorm3d(
+ InstanceNorm3dOptions(4).eps(0.5).momentum(0.1).affine(false)
+ .track_running_stats(true))),
+ "torch::nn::InstanceNorm3d(4, eps=0.5, momentum=0.1, affine=false, track_running_stats=true)");
+}
+
TEST_F(ModulesTest, PrettyPrintLayerNorm) {
ASSERT_EQ(
c10::str(LayerNorm(LayerNormOptions({2, 2}))),
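
The Stateful/Stateless test pairs above pin down the registration contract: buffers and parameters exist only when the corresponding flag is set. From user code, this looks roughly like the following sketch:

    torch::nn::InstanceNorm1d tracked(
        torch::nn::InstanceNorm1dOptions(5).track_running_stats(true).affine(true));
    // running_mean / running_var / num_batches_tracked are registered buffers,
    // and weight / bias are registered parameters, so all five are defined.

    torch::nn::InstanceNorm1d plain(torch::nn::InstanceNorm1dOptions(5));
    // With the defaults (both flags false), all five tensors are undefined.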
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index 7a80b4c..55d14be 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -75,9 +75,9 @@
torch.nn.BatchNorm3d|Yes|No
torch.nn.GroupNorm|Yes|No
torch.nn.SyncBatchNorm|No|No
-torch.nn.InstanceNorm1d|No|No
-torch.nn.InstanceNorm2d|No|No
-torch.nn.InstanceNorm3d|No|No
+torch.nn.InstanceNorm1d|Yes|No
+torch.nn.InstanceNorm2d|Yes|No
+torch.nn.InstanceNorm3d|Yes|No
torch.nn.LayerNorm|Yes|No
torch.nn.LocalResponseNorm|Yes|No
torch.nn.CrossMapLRN2d|Yes|No
diff --git a/tools/build_variables.py b/tools/build_variables.py
index 80d7dbc..1e132db 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -221,6 +221,7 @@
"torch/csrc/api/src/nn/modules/activation.cpp",
"torch/csrc/api/src/nn/modules/batchnorm.cpp",
"torch/csrc/api/src/nn/modules/normalization.cpp",
+ "torch/csrc/api/src/nn/modules/instancenorm.cpp",
"torch/csrc/api/src/nn/modules/conv.cpp",
"torch/csrc/api/src/nn/modules/dropout.cpp",
"torch/csrc/api/src/nn/modules/distance.cpp",
@@ -239,6 +240,7 @@
"torch/csrc/api/src/nn/options/batchnorm.cpp",
"torch/csrc/api/src/nn/options/conv.cpp",
"torch/csrc/api/src/nn/options/dropout.cpp",
+ "torch/csrc/api/src/nn/options/instancenorm.cpp",
"torch/csrc/api/src/nn/options/linear.cpp",
"torch/csrc/api/src/nn/options/normalization.cpp",
"torch/csrc/api/src/nn/options/embedding.cpp",
diff --git a/torch/csrc/api/include/torch/nn/functional.h b/torch/csrc/api/include/torch/nn/functional.h
index 8fe4af7..0a9db5c 100644
--- a/torch/csrc/api/include/torch/nn/functional.h
+++ b/torch/csrc/api/include/torch/nn/functional.h
@@ -14,3 +14,4 @@
#include <torch/nn/functional/pooling.h>
#include <torch/nn/functional/upsampling.h>
#include <torch/nn/functional/vision.h>
+#include <torch/nn/functional/instancenorm.h>
diff --git a/torch/csrc/api/include/torch/nn/functional/instancenorm.h b/torch/csrc/api/include/torch/nn/functional/instancenorm.h
new file mode 100644
index 0000000..b5534d7
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/functional/instancenorm.h
@@ -0,0 +1,30 @@
+#pragma once
+
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+namespace functional {
+
+namespace detail {
+inline Tensor instance_norm(const Tensor& input, const Tensor& running_mean,
+ const Tensor& running_var, const Tensor& weight, const Tensor& bias,
+ bool use_input_stats, double momentum, double eps) {
+
+ return torch::instance_norm(
+ input, weight, bias, running_mean, running_var,
+ use_input_stats, momentum, eps, at::globalContext().userEnabledCuDNN()
+ );
+}
+} // namespace detail
+
+inline Tensor instance_norm(const Tensor& input, const InstanceNormFuncOptions& options = {}) {
+ return detail::instance_norm(
+ input, options.running_mean(),
+ options.running_var(), options.weight(), options.bias(),
+ options.use_input_stats(), options.momentum(), options.eps());
+}
+
+} // namespace functional
+} // namespace nn
+} // namespace torch
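
A hedged example of calling the new functional from user code (the F alias is just a local convention):

    namespace F = torch::nn::functional;

    auto input = torch::randn({2, 5, 4});
    auto out = F::instance_norm(input);  // all-default options
    auto out2 = F::instance_norm(
        input,
        F::InstanceNormFuncOptions().momentum(0.1).eps(1e-5));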
diff --git a/torch/csrc/api/include/torch/nn/modules.h b/torch/csrc/api/include/torch/nn/modules.h
index 2239e02..57a9213 100644
--- a/torch/csrc/api/include/torch/nn/modules.h
+++ b/torch/csrc/api/include/torch/nn/modules.h
@@ -9,6 +9,7 @@
// Layers
#include <torch/nn/modules/batchnorm.h>
+#include <torch/nn/modules/instancenorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/distance.h>
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index 46caf6a..9272c42 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -1,12 +1,16 @@
#pragma once
#include <torch/nn/cloneable.h>
+#include <torch/nn/functional/batchnorm.h>
#include <torch/nn/options/batchnorm.h>
+#include <torch/nn/init.h>
#include <torch/nn/pimpl.h>
#include <torch/types.h>
#include <cstdint>
+namespace F = torch::nn::functional;
+
namespace torch {
namespace nn {
@@ -77,28 +81,55 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-/// Base class for all (dimension-specialized) batchnorm modules.
-template <size_t D, typename Derived>
-class TORCH_API BatchNormImplBase : public torch::nn::Cloneable<Derived> {
+/// Base class for all (dimension-specialized) batchnorm and instancenorm modules.
+template <size_t D, typename Derived, typename DerivedOptions>
+class NormImplBase : public torch::nn::Cloneable<Derived> {
protected:
virtual void _check_input_dim(const Tensor& input) = 0;
public:
- explicit BatchNormImplBase(const BatchNormOptions& options_);
+ NormImplBase(const DerivedOptions& options_) : options(options_) {
+ reset();
+ }
- Tensor forward(const Tensor& input);
+ void reset() override {
+ if (options.affine()) {
+ weight = this->register_parameter("weight", torch::empty({options.num_features()}));
+ bias = this->register_parameter("bias", torch::empty({options.num_features()}));
+ } else {
+ weight = this->register_parameter("weight", Tensor());
+ bias = this->register_parameter("bias", Tensor());
+ }
+ if (options.track_running_stats()) {
+ running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
+ running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
+ num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
+ } else {
+ running_mean = this->register_buffer("running_mean", Tensor());
+ running_var = this->register_buffer("running_var", Tensor());
+ num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
+ }
+ reset_parameters();
+ }
- void reset() override;
+ void reset_running_stats() {
+ if (options.track_running_stats()) {
+ running_mean.zero_();
+ running_var.fill_(1);
+ num_batches_tracked.zero_();
+ }
+ }
- void reset_running_stats();
-
- void reset_parameters();
-
- /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
- void pretty_print(std::ostream& stream) const override;
+ void reset_parameters() {
+ reset_running_stats();
+ if (options.affine()) {
+ torch::nn::init::ones_(weight);
+ torch::nn::init::zeros_(bias);
+ }
+ }
/// The options with which this module was constructed.
- BatchNormOptions options;
+ DerivedOptions options;
/// The learned weight.
/// Only defined if the `affine` option was `true` upon construction.
@@ -121,6 +152,47 @@
Tensor num_batches_tracked;
};
+/// Base class for all (dimension-specialized) batchnorm modules.
+template <size_t D, typename Derived>
+class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
+ public:
+ using NormImplBase<D, Derived, BatchNormOptions>::NormImplBase;
+
+ Tensor forward(const Tensor& input) {
+ this->_check_input_dim(input);
+ double exponential_average_factor;
+ if (this->options.momentum() == c10::nullopt) {
+ exponential_average_factor = 0.0;
+ } else {
+ exponential_average_factor = this->options.momentum().value();
+ }
+
+ if (this->is_training() && this->options.track_running_stats()) {
+ if (this->num_batches_tracked.defined()) {
+ this->num_batches_tracked += 1;
+ if (this->options.momentum() == c10::nullopt) { // use cumulative moving average
+ exponential_average_factor = 1.0 / this->num_batches_tracked.template item<double>();
+ } else { // use exponential moving average
+ exponential_average_factor = this->options.momentum().value();
+ }
+ }
+ }
+
+ return F::detail::batch_norm(
+ input,
+ this->running_mean,
+ this->running_var,
+ this->weight,
+ this->bias,
+ this->is_training() || !this->options.track_running_stats(),
+ /*momentum=*/exponential_average_factor,
+ this->options.eps());
+ }
+
+ /// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+};
+
/// Applies the BatchNorm1d function.
/// See https://pytorch.org/docs/master/nn.html#torch.nn.BatchNorm1d to learn
/// about the exact behavior of this module.
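
The refactor above is the heart of the change: parameter and buffer bookkeeping moves from BatchNormImplBase into the new NormImplBase, parameterized on the options type, so batch norm and instance norm share reset(), reset_running_stats(), and reset_parameters(), while each family supplies its own forward() and _check_input_dim(). As a hedged sketch of the pattern, a hypothetical third family could reuse the base like this (MyNorm1dImpl/MyNorm1d are illustrative names, not part of the patch):

    class MyNorm1dImpl
        : public torch::nn::NormImplBase<1, MyNorm1dImpl, torch::nn::BatchNormOptions> {
     protected:
      void _check_input_dim(const torch::Tensor& input) override {
        TORCH_CHECK(input.dim() == 3, "expected 3D input (got ", input.dim(), "D input)");
      }

     public:
      using NormImplBase<1, MyNorm1dImpl, torch::nn::BatchNormOptions>::NormImplBase;

      torch::Tensor forward(const torch::Tensor& input) {
        _check_input_dim(input);
        // weight/bias and the running-stat buffers were registered by the
        // base's reset(); forward() only wires them into a functional call.
        return torch::nn::functional::detail::batch_norm(
            input, running_mean, running_var, weight, bias,
            is_training() || !options.track_running_stats(),
            options.momentum().value_or(0.1), options.eps());
      }
    };
    TORCH_MODULE(MyNorm1d);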
diff --git a/torch/csrc/api/include/torch/nn/modules/instancenorm.h b/torch/csrc/api/include/torch/nn/modules/instancenorm.h
new file mode 100644
index 0000000..21164d9
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/modules/instancenorm.h
@@ -0,0 +1,66 @@
+#pragma once
+
+#include <torch/nn/modules/batchnorm.h>
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+
+/// Base class for all (dimension-specialized) instance norm modules
+template <size_t D, typename Derived>
+class InstanceNormImpl : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
+ public:
+ using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;
+
+ Tensor forward(const Tensor& input) {
+ this->_check_input_dim(input);
+ return F::detail::instance_norm(
+ input, this->running_mean, this->running_var, this->weight, this->bias,
+ this->is_training() || !this->options.track_running_stats(), this->options.momentum(), this->options.eps());
+ }
+
+ /// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
+ void pretty_print(std::ostream& stream) const override;
+};
+
+/// Applies the InstanceNorm1d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm1d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm1dImpl : public InstanceNormImpl<1, InstanceNorm1dImpl> {
+ protected:
+ virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+ using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm1d);
+
+/// Applies the InstanceNorm2d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm2d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm2dImpl : public InstanceNormImpl<2, InstanceNorm2dImpl> {
+ protected:
+ virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+ using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm2d);
+
+/// Applies the InstanceNorm3d function.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.InstanceNorm3d to learn
+/// about the exact behavior of this module.
+class TORCH_API InstanceNorm3dImpl : public InstanceNormImpl<3, InstanceNorm3dImpl> {
+ protected:
+ virtual void _check_input_dim(const Tensor& input) override;
+
+ public:
+ using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
+};
+
+TORCH_MODULE(InstanceNorm3d);
+
+} // namespace nn
+} // namespace torch
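
One behavioral point worth noting in forward(): use_input_stats is is_training() || !track_running_stats(), so with the instance-norm default of track_running_stats(false), the module keeps normalizing with per-instance statistics even in eval() mode, unlike the batch-norm modules. For example:

    torch::nn::InstanceNorm1d norm(5);
    norm->eval();
    // is_training() is false, but !track_running_stats() is true, so the
    // forward pass still uses the statistics of each input instance.
    auto out = norm->forward(torch::randn({2, 5, 4}));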
diff --git a/torch/csrc/api/include/torch/nn/options/batchnorm.h b/torch/csrc/api/include/torch/nn/options/batchnorm.h
index 5495c07..1770191 100644
--- a/torch/csrc/api/include/torch/nn/options/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/options/batchnorm.h
@@ -43,7 +43,7 @@
namespace functional {
-/// Options for the `BatchNorm` module.
+/// Options for the `BatchNorm` functional.
struct TORCH_API BatchNormFuncOptions {
TORCH_ARG(Tensor, weight) = Tensor();
diff --git a/torch/csrc/api/include/torch/nn/options/instancenorm.h b/torch/csrc/api/include/torch/nn/options/instancenorm.h
new file mode 100644
index 0000000..8c536c2
--- /dev/null
+++ b/torch/csrc/api/include/torch/nn/options/instancenorm.h
@@ -0,0 +1,59 @@
+#pragma once
+
+#include <torch/arg.h>
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/nn/options/batchnorm.h>
+#include <torch/types.h>
+
+namespace torch {
+namespace nn {
+
+/// Options for the `InstanceNorm` module.
+struct TORCH_API InstanceNormOptions {
+ /* implicit */ InstanceNormOptions(int64_t num_features);
+
+ /// The number of features of the input tensor.
+ TORCH_ARG(int64_t, num_features);
+
+ /// The epsilon value added for numerical stability.
+ TORCH_ARG(double, eps) = 1e-5;
+
+ /// A momentum multiplier for the mean and variance.
+ TORCH_ARG(double, momentum) = 0.1;
+
+ /// Whether to learn a scale and bias that are applied in an affine
+ /// transformation on the input.
+ TORCH_ARG(bool, affine) = false;
+
+ /// Whether to store and update batch statistics (mean and variance) in the
+ /// module.
+ TORCH_ARG(bool, track_running_stats) = false;
+};
+
+using InstanceNorm1dOptions = InstanceNormOptions;
+using InstanceNorm2dOptions = InstanceNormOptions;
+using InstanceNorm3dOptions = InstanceNormOptions;
+
+namespace functional {
+
+/// Options for the `InstanceNorm` functional.
+struct TORCH_API InstanceNormFuncOptions {
+ TORCH_ARG(Tensor, running_mean) = Tensor();
+
+ TORCH_ARG(Tensor, running_var) = Tensor();
+
+ TORCH_ARG(Tensor, weight) = Tensor();
+
+ TORCH_ARG(Tensor, bias) = Tensor();
+
+ TORCH_ARG(bool, use_input_stats) = true;
+
+ TORCH_ARG(double, momentum) = 0.1;
+
+ TORCH_ARG(double, eps) = 1e-5;
+};
+
+} // namespace functional
+
+} // namespace nn
+} // namespace torch
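
As with the other options structs, the TORCH_ARG macros generate chainable setters; a construction example:

    auto opts = torch::nn::InstanceNorm2dOptions(16)
                    .eps(1e-4)
                    .momentum(0.05)
                    .affine(true)
                    .track_running_stats(true);
    torch::nn::InstanceNorm2d norm(opts);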
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 2311850..ec05a9c 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -3,7 +3,6 @@
#include <torch/cuda.h>
#include <torch/types.h>
-#include <torch/nn/init.h>
#include <c10/util/Exception.h>
@@ -17,7 +16,7 @@
namespace torch {
namespace nn {
-BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) {
+BatchNormImpl::BatchNormImpl(const BatchNormOptions& options_) : options(options_) { // NOLINT(modernize-pass-by-value)
TORCH_WARN("torch::nn::BatchNorm module is deprecated and will be removed in 1.5. "
"Use BatchNorm{1,2,3}d instead.");
reset();
@@ -78,93 +77,17 @@
torch::cuda::cudnn_is_available());
}
-template <size_t D, typename Derived>
-BatchNormImplBase<D, Derived>::BatchNormImplBase(const BatchNormOptions& options_)
- : options(options_) {
- reset();
-}
+// ===========================================================================
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset() {
- if (options.affine()) {
- weight = this->register_parameter("weight", torch::empty({options.num_features()}));
- bias = this->register_parameter("bias", torch::empty({options.num_features()}));
- } else {
- weight = this->register_parameter("weight", Tensor());
- bias = this->register_parameter("bias", Tensor());
- }
- if (options.track_running_stats()) {
- running_mean = this->register_buffer("running_mean", torch::zeros({options.num_features()}));
- running_var = this->register_buffer("running_var", torch::ones({options.num_features()}));
- num_batches_tracked = this->register_buffer("num_batches_tracked", torch::tensor(0, torch::dtype(torch::kLong)));
- } else {
- running_mean = this->register_buffer("running_mean", Tensor());
- running_var = this->register_buffer("running_var", Tensor());
- num_batches_tracked = this->register_buffer("num_batches_tracked", Tensor());
- }
- reset_parameters();
-}
-
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset_running_stats() {
- if (options.track_running_stats()) {
- running_mean.zero_();
- running_var.fill_(1);
- num_batches_tracked.zero_();
- }
-}
-
-template <size_t D, typename Derived>
-void BatchNormImplBase<D, Derived>::reset_parameters() {
- reset_running_stats();
- if (options.affine()) {
- torch::nn::init::ones_(weight);
- torch::nn::init::zeros_(bias);
- }
-}
-
-template <size_t D, typename Derived>
+template <size_t D, typename Derived>
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
stream << std::boolalpha
<< "torch::nn::BatchNorm" << D << "d("
- << options.num_features() << ", "
- << "eps=" << options.eps() << ", "
- << "momentum=" << options.momentum().value() << ", "
- << "affine=" << options.affine() << ", "
- << "track_running_stats=" << options.track_running_stats() << ")";
-}
-
-template <size_t D, typename Derived>
-Tensor BatchNormImplBase<D, Derived>::forward(const Tensor& input) {
- _check_input_dim(input);
-
- double exponential_average_factor;
- if (options.momentum() == c10::nullopt) {
- exponential_average_factor = 0.0;
- } else {
- exponential_average_factor = options.momentum().value();
- }
-
- if (this->is_training() && options.track_running_stats()) {
- if (num_batches_tracked.defined()) {
- num_batches_tracked += 1;
- if (options.momentum() == c10::nullopt) { // use cumulative moving average
- exponential_average_factor = 1.0 / num_batches_tracked.item<double>();
- } else { // use exponential moving average
- exponential_average_factor = options.momentum().value();
- }
- }
- }
-
- return F::detail::batch_norm(
- input,
- running_mean,
- running_var,
- weight,
- bias,
- this->is_training() || !options.track_running_stats(),
- /*momentum=*/exponential_average_factor,
- options.eps());
+ << this->options.num_features() << ", "
+ << "eps=" << this->options.eps() << ", "
+ << "momentum=" << this->options.momentum().value() << ", "
+ << "affine=" << this->options.affine() << ", "
+ << "track_running_stats=" << this->options.track_running_stats() << ")";
}
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
diff --git a/torch/csrc/api/src/nn/modules/instancenorm.cpp b/torch/csrc/api/src/nn/modules/instancenorm.cpp
new file mode 100644
index 0000000..c400c94
--- /dev/null
+++ b/torch/csrc/api/src/nn/modules/instancenorm.cpp
@@ -0,0 +1,57 @@
+#include <torch/nn/functional/instancenorm.h>
+#include <torch/nn/modules/instancenorm.h>
+
+namespace F = torch::nn::functional;
+
+namespace torch {
+namespace nn {
+
+template <size_t D, typename Derived>
+void InstanceNormImpl<D, Derived>::pretty_print(std::ostream& stream) const {
+ stream << std::boolalpha
+ << "torch::nn::InstanceNorm" << D << "d("
+ << this->options.num_features() << ", "
+ << "eps=" << this->options.eps() << ", "
+ << "momentum=" << this->options.momentum() << ", "
+ << "affine=" << this->options.affine() << ", "
+ << "track_running_stats=" << this->options.track_running_stats() << ")";
+}
+
+void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
+ if (input.dim() == 2) {
+ TORCH_CHECK(
+ false,
+ "InstanceNorm1d returns 0-filled tensor to 2D tensor.",
+ "This is because InstanceNorm1d reshapes inputs to",
+ "(1, N * C, ...) from (N, C,...) and this makes",
+ "variances 0.");
+ }
+ if (input.dim() != 3) {
+ TORCH_CHECK(
+ false,
+ "expected 3D input (got ", input.dim(), "D input)");
+ }
+}
+
+void InstanceNorm2dImpl::_check_input_dim(const Tensor& input) {
+ if (input.dim() != 4) {
+ TORCH_CHECK(
+ false,
+ "expected 4D input (got ", input.dim(), "D input)");
+ }
+}
+
+void InstanceNorm3dImpl::_check_input_dim(const Tensor& input) {
+ if (input.dim() != 5) { // NOLINT(cppcoreguidelines-avoid-magic-numbers)
+ TORCH_CHECK(
+ false,
+ "expected 5D input (got ", input.dim(), "D input)");
+ }
+}
+
+template class InstanceNormImpl<1, InstanceNorm1dImpl>;
+template class InstanceNormImpl<2, InstanceNorm2dImpl>;
+template class InstanceNormImpl<3, InstanceNorm3dImpl>;
+
+} // namespace nn
+} // namespace torch
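
With these overrides in place, a 2D input to InstanceNorm1d now fails loudly instead of silently producing zeros (the reshape to (1, N * C, ...) would make every group's variance zero). Roughly:

    torch::nn::InstanceNorm1d norm(5);
    try {
      norm->forward(torch::randn({2, 5}));  // 2D input: rejected
    } catch (const c10::Error& e) {
      // message explains the (1, N * C, ...) reshape issue
    }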
diff --git a/torch/csrc/api/src/nn/options/instancenorm.cpp b/torch/csrc/api/src/nn/options/instancenorm.cpp
new file mode 100644
index 0000000..930fc27
--- /dev/null
+++ b/torch/csrc/api/src/nn/options/instancenorm.cpp
@@ -0,0 +1,9 @@
+#include <torch/nn/options/instancenorm.h>
+
+namespace torch {
+namespace nn {
+
+InstanceNormOptions::InstanceNormOptions(int64_t num_features) : num_features_(num_features) {}
+
+} // namespace nn
+} // namespace torch