Fix torch::nn::init::orthogonal_ with CNNs (#18915)
Summary:
Fixes #18518
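I changed the C++ API torch::nn::init::orthogonal_ implementation to match the Python implementation: the column count is now computed as `tensor.numel() / rows` instead of `tensor.size(1)`, so a tensor with more than two dimensions (e.g. a convolution weight) is flattened over all of its trailing dimensions. The Python implementation is also updated to compute the column count with the equivalent `tensor.numel() // rows`.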
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18915
Differential Revision: D14851833
Pulled By: ezyang
fbshipit-source-id: 45b5e9741582777c203e9ebed564ab3ac1f94baf
diff --git a/test/cpp/api/init.cpp b/test/cpp/api/init.cpp
index c4b2f97..5527d72 100644
--- a/test/cpp/api/init.cpp
+++ b/test/cpp/api/init.cpp
@@ -2,6 +2,7 @@
#include <torch/nn/init.h>
#include <torch/nn/modules/linear.h>
+#include <torch/nn/modules/conv.h>
#include <test/cpp/api/init_baseline.h>
#include <test/cpp/api/support.h>
@@ -123,4 +124,9 @@
double gain =
torch::nn::init::calculate_gain(torch::nn::init::Nonlinearity::LeakyReLU);
ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
+}
+
+TEST(InitTest, CanInitializeCnnWithOrthogonal) { // Applies orthogonal_ to a Conv2d weight; NOTE(review): presumably the regression test for #18518 — confirm.
+ torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2)); // NOTE(review): assumes the resulting weight has more than two dimensions — confirm.
+ torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]); // No assertion follows: the test only checks that this call completes.
}
\ No newline at end of file
diff --git a/torch/csrc/api/src/nn/init.cpp b/torch/csrc/api/src/nn/init.cpp
index 187a252..7d64b9f 100644
--- a/torch/csrc/api/src/nn/init.cpp
+++ b/torch/csrc/api/src/nn/init.cpp
@@ -123,7 +123,7 @@
"Only tensors with 2 or more dimensions are supported");
const auto rows = tensor.size(0);
- const auto columns = tensor.size(1);
+ const auto columns = tensor.numel() / rows;
auto flattened = torch::randn({rows, columns});
if (rows < columns) {
diff --git a/torch/nn/init.py b/torch/nn/init.py
index 731cd72..583053e 100644
--- a/torch/nn/init.py
+++ b/torch/nn/init.py
@@ -345,7 +345,7 @@
raise ValueError("Only tensors with 2 or more dimensions are supported")
rows = tensor.size(0)
- cols = tensor[0].numel()
+ cols = tensor.numel() // rows
flattened = tensor.new(rows, cols).normal_(0, 1)
if rows < cols: