[C++ API] RNN / GRU / LSTM layer refactoring (#34322)
Summary:
This PR refactors the RNN / GRU / LSTM layers in the C++ API to exactly match the implementation in the Python API.
**BC-breaking changes:**
- Instead of returning `RNNOutput`, the RNN / GRU `forward` method now returns `std::tuple<Tensor, Tensor>`, and the LSTM `forward` method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching the Python API.
- The RNN / LSTM / GRU `forward` method now accepts the same inputs as the Python API (an input tensor and, optionally, a hidden state).
- RNN / LSTM / GRU now have a `forward_with_packed_input` method which accepts a `PackedSequence` as input and, optionally, a hidden state, matching the `forward(PackedSequence, ...)` variant in the Python API (see the usage sketch after this list).
- In `RNNOptions`
- `tanh()` / `relu()` / `activation` are removed. Instead, a `nonlinearity` option is added, which takes either `torch::kTanh` or `torch::kReLU`.
- `layers` -> `num_layers`
- `with_bias` -> `bias`
- In `LSTMOptions`
- `layers` -> `num_layers`
- `with_bias` -> `bias`
- In `GRUOptions`
- `layers` -> `num_layers`
- `with_bias` -> `bias`
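
A minimal usage sketch of the revised API described above (option names and return types are taken from this PR; the surrounding `main` scaffolding and variable names are illustrative only):

```cpp
#include <torch/torch.h>

int main() {
  // `num_layers` / `bias` replace the old `layers` / `with_bias` options.
  torch::nn::LSTM lstm(torch::nn::LSTMOptions(128, 64).num_layers(3).dropout(0.2));

  // Input layout is (sequence, batch, features) unless `batch_first(true)` is set.
  auto x = torch::randn({10, 16, 128});

  // `forward` now returns std::tuple<Tensor, std::tuple<Tensor, Tensor>> instead of `RNNOutput`.
  auto lstm_out = lstm->forward(x);
  torch::Tensor output = std::get<0>(lstm_out);
  std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_out);  // (h_n, c_n)

  // The returned hidden state can be fed back into the next call.
  auto next = lstm->forward(x, state);

  // RNN: the `nonlinearity` option replaces the old `tanh()` / `relu()` setters.
  torch::nn::RNN rnn(torch::nn::RNNOptions(128, 64).nonlinearity(torch::kReLU).num_layers(2));
  auto rnn_out = rnn->forward(x);  // std::tuple<Tensor, Tensor>

  // Packed-sequence input goes through the new `forward_with_packed_input` method.
  namespace rnn_utils = torch::nn::utils::rnn;
  rnn_utils::PackedSequence packed = rnn_utils::pack_sequence({torch::ones({3, 128})});
  auto packed_out = rnn->forward_with_packed_input(packed);  // std::tuple<PackedSequence, Tensor>
  torch::Tensor packed_data = std::get<0>(packed_out).data();
  return 0;
}
```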
The majority of the changes in this PR focus on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. The RNN tests are then updated to reflect the revised API design.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322
Differential Revision: D20311699
Pulled By: yf225
fbshipit-source-id: e2b60fc7bac64367a8434647d74c08568a7b28f7
diff --git a/test/cpp/api/enum.cpp b/test/cpp/api/enum.cpp
index 12a8879..ac20edb 100644
--- a/test/cpp/api/enum.cpp
+++ b/test/cpp/api/enum.cpp
@@ -43,7 +43,11 @@
torch::enumtype::kBatchMean,
torch::enumtype::kZeros,
torch::enumtype::kBorder,
- torch::enumtype::kReflection
+ torch::enumtype::kReflection,
+ torch::enumtype::kRNN_TANH,
+ torch::enumtype::kRNN_RELU,
+ torch::enumtype::kLSTM,
+ torch::enumtype::kGRU
> v;
TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@@ -76,4 +80,8 @@
TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
TORCH_ENUM_PRETTY_PRINT_TEST(Border)
TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
+ TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
+ TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
+ TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
+ TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
}
diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp
index 3a688d0..5aa4ccb 100644
--- a/test/cpp/api/modulelist.cpp
+++ b/test/cpp/api/modulelist.cpp
@@ -283,7 +283,7 @@
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
- " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+ " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
")");
}
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index c17a0ab..c51c386 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -23,7 +23,7 @@
auto B = x.size(1);
x = x.view({T * B, 1});
x = l1->forward(x).view({T, B, nhid}).tanh_();
- x = rnn->forward(x).output[T - 1];
+ x = std::get<0>(rnn->forward(x))[T - 1];
x = lo->forward(x);
return x;
};
@@ -61,29 +61,39 @@
return true;
};
-void check_lstm_sizes(RNNOutput output) {
+void check_lstm_sizes(std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output) {
// Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
// 10 and 16 time steps (10 x 16 x n)
- ASSERT_EQ(output.output.ndimension(), 3);
- ASSERT_EQ(output.output.size(0), 10);
- ASSERT_EQ(output.output.size(1), 16);
- ASSERT_EQ(output.output.size(2), 64);
+ torch::Tensor output = std::get<0>(lstm_output);
+ std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
+ torch::Tensor hx = std::get<0>(state);
+ torch::Tensor cx = std::get<1>(state);
- ASSERT_EQ(output.state.ndimension(), 4);
- ASSERT_EQ(output.state.size(0), 2); // (hx, cx)
- ASSERT_EQ(output.state.size(1), 3); // layers
- ASSERT_EQ(output.state.size(2), 16); // Batchsize
- ASSERT_EQ(output.state.size(3), 64); // 64 hidden dims
+ ASSERT_EQ(output.ndimension(), 3);
+ ASSERT_EQ(output.size(0), 10);
+ ASSERT_EQ(output.size(1), 16);
+ ASSERT_EQ(output.size(2), 64);
+
+ ASSERT_EQ(hx.ndimension(), 3);
+ ASSERT_EQ(hx.size(0), 3); // layers
+ ASSERT_EQ(hx.size(1), 16); // Batchsize
+ ASSERT_EQ(hx.size(2), 64); // 64 hidden dims
+
+ ASSERT_EQ(cx.ndimension(), 3);
+ ASSERT_EQ(cx.size(0), 3); // layers
+ ASSERT_EQ(cx.size(1), 16); // Batchsize
+ ASSERT_EQ(cx.size(2), 64); // 64 hidden dims
// Something is in the hiddens
- ASSERT_GT(output.state.norm().item<float>(), 0);
+ ASSERT_GT(hx.norm().item<float>(), 0);
+ ASSERT_GT(cx.norm().item<float>(), 0);
}
struct RNNTest : torch::test::SeedingFixture {};
TEST_F(RNNTest, CheckOutputSizes) {
- LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+ LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
// Input size is: sequence length, batch size, input size
auto x = torch::randn({10, 16, 128}, torch::requires_grad());
auto output = model->forward(x);
@@ -92,11 +102,17 @@
y.backward();
check_lstm_sizes(output);
- auto next = model->forward(x, output.state);
+ auto next = model->forward(x, std::get<1>(output));
check_lstm_sizes(next);
- torch::Tensor diff = next.state - output.state;
+ auto output_hx = std::get<0>(std::get<1>(output));
+ auto output_cx = std::get<1>(std::get<1>(output));
+
+ auto next_hx = std::get<0>(std::get<1>(next));
+ auto next_cx = std::get<1>(std::get<1>(next));
+
+ torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
// Hiddens changed
ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -122,12 +138,12 @@
}
auto out = model->forward(x);
- ASSERT_EQ(out.output.ndimension(), 3);
- ASSERT_EQ(out.output.size(0), 3);
- ASSERT_EQ(out.output.size(1), 4);
- ASSERT_EQ(out.output.size(2), 2);
+ ASSERT_EQ(std::get<0>(out).ndimension(), 3);
+ ASSERT_EQ(std::get<0>(out).size(0), 3);
+ ASSERT_EQ(std::get<0>(out).size(1), 4);
+ ASSERT_EQ(std::get<0>(out).size(2), 2);
- auto flat = out.output.view(3 * 4 * 2);
+ auto flat = std::get<0>(out).view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -136,12 +152,20 @@
ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
}
- ASSERT_EQ(out.state.ndimension(), 4); // (hx, cx) x layers x B x 2
- ASSERT_EQ(out.state.size(0), 2);
- ASSERT_EQ(out.state.size(1), 1);
- ASSERT_EQ(out.state.size(2), 4);
- ASSERT_EQ(out.state.size(3), 2);
- flat = out.state.view(16);
+ auto hx = std::get<0>(std::get<1>(out));
+ auto cx = std::get<1>(std::get<1>(out));
+
+ ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
+ ASSERT_EQ(hx.size(0), 1);
+ ASSERT_EQ(hx.size(1), 4);
+ ASSERT_EQ(hx.size(2), 2);
+
+ ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
+ ASSERT_EQ(cx.size(0), 1);
+ ASSERT_EQ(cx.size(1), 4);
+ ASSERT_EQ(cx.size(2), 2);
+
+ flat = torch::cat({hx, cx}, 0).view(16);
float h_out[] = {0.7889,
0.9003,
0.7769,
@@ -165,27 +189,27 @@
TEST_F(RNNTest, EndToEndLSTM) {
ASSERT_TRUE(test_RNN_xor<LSTM>(
- [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
+ [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }));
}
TEST_F(RNNTest, EndToEndGRU) {
ASSERT_TRUE(
- test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
+ test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }));
}
TEST_F(RNNTest, EndToEndRNNRelu) {
ASSERT_TRUE(test_RNN_xor<RNN>(
- [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
+ [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }));
}
TEST_F(RNNTest, EndToEndRNNTanh) {
ASSERT_TRUE(test_RNN_xor<RNN>(
- [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
+ [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }));
}
TEST_F(RNNTest, Sizes_CUDA) {
torch::manual_seed(0);
- LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+ LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
model->to(torch::kCUDA);
auto x =
torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
@@ -195,11 +219,17 @@
y.backward();
check_lstm_sizes(output);
- auto next = model->forward(x, output.state);
+ auto next = model->forward(x, std::get<1>(output));
check_lstm_sizes(next);
- torch::Tensor diff = next.state - output.state;
+ auto output_hx = std::get<0>(std::get<1>(output));
+ auto output_cx = std::get<1>(std::get<1>(output));
+
+ auto next_hx = std::get<0>(std::get<1>(next));
+ auto next_cx = std::get<1>(std::get<1>(next));
+
+ torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
// Hiddens changed
ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -207,51 +237,68 @@
TEST_F(RNNTest, EndToEndLSTM_CUDA) {
ASSERT_TRUE(test_RNN_xor<LSTM>(
- [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
+ [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndGRU_CUDA) {
ASSERT_TRUE(test_RNN_xor<GRU>(
- [](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
+ [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
ASSERT_TRUE(test_RNN_xor<RNN>(
- [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
+ [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }, true));
}
TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
ASSERT_TRUE(test_RNN_xor<RNN>(
- [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
+ [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }, true));
}
TEST_F(RNNTest, PrettyPrintRNNs) {
ASSERT_EQ(
- c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
- "torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
+ c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))),
+ "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
ASSERT_EQ(
- c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
- "torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
+ c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))),
+ "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");
ASSERT_EQ(
- c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
- "torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
+ c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh))),
+ "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
}
// This test assures that flatten_parameters does not crash,
// when bidirectional is set to true
// https://github.com/pytorch/pytorch/issues/19545
TEST_F(RNNTest, BidirectionalFlattenParameters) {
- GRU gru(GRUOptions(100, 256).layers(2).bidirectional(true));
+ GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
gru->flatten_parameters();
}
template <typename Impl>
-void copyParameters(torch::nn::ModuleHolder<Impl>& target, size_t t_i,
- const torch::nn::ModuleHolder<Impl>& source, size_t s_i) {
+void copyParameters(torch::nn::ModuleHolder<Impl>& target, std::string t_suffix,
+ const torch::nn::ModuleHolder<Impl>& source, std::string s_suffix) {
at::NoGradGuard guard;
- target->w_ih[t_i].copy_(source->w_ih[s_i]);
- target->w_hh[t_i].copy_(source->w_hh[s_i]);
- target->b_ih[t_i].copy_(source->b_ih[s_i]);
- target->b_hh[t_i].copy_(source->b_hh[s_i]);
+ target->named_parameters()["weight_ih_l" + t_suffix].copy_(source->named_parameters()["weight_ih_l" + s_suffix]);
+ target->named_parameters()["weight_hh_l" + t_suffix].copy_(source->named_parameters()["weight_hh_l" + s_suffix]);
+ target->named_parameters()["bias_ih_l" + t_suffix].copy_(source->named_parameters()["bias_ih_l" + s_suffix]);
+ target->named_parameters()["bias_hh_l" + t_suffix].copy_(source->named_parameters()["bias_hh_l" + s_suffix]);
+}
+
+std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
+ std::tuple<torch::Tensor, torch::Tensor> gru_output, torch::Device device) {
+ return std::make_tuple(
+ std::get<0>(gru_output).to(device),
+ std::get<1>(gru_output).to(device));
+}
+
+std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output_to_device(
+ std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output, torch::Device device) {
+ auto hidden_states = std::get<1>(lstm_output);
+ return std::make_tuple(
+ std::get<0>(lstm_output).to(device),
+ std::make_tuple(
+ std::get<0>(hidden_states).to(device),
+ std::get<1>(hidden_states).to(device)));
}
// This test is a port of python code introduced here:
@@ -264,7 +311,7 @@
auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
- auto gru_options = GRUOptions(1, 1).layers(1).batch_first(false);
+ auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
GRU bi_grus {gru_options.bidirectional(true)};
GRU reverse_gru {gru_options.bidirectional(false)};
@@ -275,28 +322,26 @@
// Now make sure the weights of the reverse gru layer match
// ones of the (reversed) bidirectional's:
- copyParameters(reverse_gru, 0, bi_grus, 1);
+ copyParameters(reverse_gru, "0", bi_grus, "0_reverse");
auto bi_output = bi_grus->forward(input);
auto reverse_output = reverse_gru->forward(input_reversed);
if (cuda) {
- bi_output.output = bi_output.output.to(torch::kCPU);
- bi_output.state = bi_output.state.to(torch::kCPU);
- reverse_output.output = reverse_output.output.to(torch::kCPU);
- reverse_output.state = reverse_output.state.to(torch::kCPU);
+ bi_output = gru_output_to_device(bi_output, torch::kCPU);
+ reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
}
- ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
- auto size = bi_output.output.size(0);
+ ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+ auto size = std::get<0>(bi_output).size(0);
for (int i = 0; i < size; i++) {
- ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
- reverse_output.output[size - 1 - i][0][0].item<float>());
+ ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+ std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
}
// The hidden states of the reversed GRUs sits
// in the odd indices in the first dimension.
- ASSERT_EQ(bi_output.state[1][0][0].item<float>(),
- reverse_output.state[0][0][0].item<float>());
+ ASSERT_EQ(std::get<1>(bi_output)[1][0][0].item<float>(),
+ std::get<1>(reverse_output)[0][0][0].item<float>());
}
TEST_F(RNNTest, BidirectionalGRUReverseForward) {
@@ -315,7 +360,7 @@
auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
- auto lstm_opt = GRUOptions(1, 1).layers(1).batch_first(false);
+ auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);
LSTM bi_lstm {lstm_opt.bidirectional(true)};
LSTM reverse_lstm {lstm_opt.bidirectional(false)};
@@ -327,30 +372,28 @@
// Now make sure the weights of the reverse lstm layer match
// ones of the (reversed) bidirectional's:
- copyParameters(reverse_lstm, 0, bi_lstm, 1);
+ copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");
auto bi_output = bi_lstm->forward(input);
auto reverse_output = reverse_lstm->forward(input_reversed);
if (cuda) {
- bi_output.output = bi_output.output.to(torch::kCPU);
- bi_output.state = bi_output.state.to(torch::kCPU);
- reverse_output.output = reverse_output.output.to(torch::kCPU);
- reverse_output.state = reverse_output.state.to(torch::kCPU);
+ bi_output = lstm_output_to_device(bi_output, torch::kCPU);
+ reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
}
- ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
- auto size = bi_output.output.size(0);
+ ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+ auto size = std::get<0>(bi_output).size(0);
for (int i = 0; i < size; i++) {
- ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
- reverse_output.output[size - 1 - i][0][0].item<float>());
+ ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+ std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
}
// The hidden states of the reversed LSTM sits
// in the odd indices in the first dimension.
- ASSERT_EQ(bi_output.state[0][1][0][0].item<float>(),
- reverse_output.state[0][0][0][0].item<float>());
- ASSERT_EQ(bi_output.state[1][1][0][0].item<float>(),
- reverse_output.state[1][0][0][0].item<float>());
+ ASSERT_EQ(std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
+ std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
+ ASSERT_EQ(std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
+ std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
}
TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
@@ -363,19 +406,15 @@
TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
// Create two GRUs with the same options
- auto opt = GRUOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+ auto opt = GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
GRU gru_cpu {opt};
GRU gru_cuda {opt};
// Copy weights and biases from CPU GRU to CUDA GRU
{
at::NoGradGuard guard;
- const auto num_directions = gru_cpu->options.bidirectional() ? 2 : 1;
- for (int64_t layer = 0; layer < gru_cpu->options.layers(); layer++) {
- for (auto direction = 0; direction < num_directions; direction++) {
- const auto layer_idx = (layer * num_directions) + direction;
- copyParameters(gru_cuda, layer_idx, gru_cpu, layer_idx);
- }
+ for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
+ gru_cuda->named_parameters()[param.key()].copy_(gru_cpu->named_parameters()[param.key()]);
}
}
@@ -397,20 +436,19 @@
auto output_cpu = gru_cpu->forward(input_cpu);
auto output_cuda = gru_cuda->forward(input_cuda);
- output_cpu.output = output_cpu.output.to(torch::kCPU);
- output_cpu.state = output_cpu.state.to(torch::kCPU);
+ output_cpu = gru_output_to_device(output_cpu, torch::kCPU);
// Assert that the output and state are equal on CPU and CUDA
- ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
- for (int i = 0; i < output_cpu.output.dim(); i++) {
- ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+ ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+ for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+ ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
}
- for (int i = 0; i < output_cpu.output.size(0); i++) {
- for (int j = 0; j < output_cpu.output.size(1); j++) {
- for (int k = 0; k < output_cpu.output.size(2); k++) {
+ for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+ for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+ for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
ASSERT_NEAR(
- output_cpu.output[i][j][k].item<float>(),
- output_cuda.output[i][j][k].item<float>(), 1e-5);
+ std::get<0>(output_cpu)[i][j][k].item<float>(),
+ std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
}
}
}
@@ -418,19 +456,15 @@
TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
// Create two LSTMs with the same options
- auto opt = LSTMOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+ auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
LSTM lstm_cpu {opt};
LSTM lstm_cuda {opt};
// Copy weights and biases from CPU LSTM to CUDA LSTM
{
at::NoGradGuard guard;
- const auto num_directions = lstm_cpu->options.bidirectional() ? 2 : 1;
- for (int64_t layer = 0; layer < lstm_cpu->options.layers(); layer++) {
- for (auto direction = 0; direction < num_directions; direction++) {
- const auto layer_idx = (layer * num_directions) + direction;
- copyParameters(lstm_cuda, layer_idx, lstm_cpu, layer_idx);
- }
+ for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
+ lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]);
}
}
@@ -451,21 +485,68 @@
auto output_cpu = lstm_cpu->forward(input_cpu);
auto output_cuda = lstm_cuda->forward(input_cuda);
- output_cpu.output = output_cpu.output.to(torch::kCPU);
- output_cpu.state = output_cpu.state.to(torch::kCPU);
+ output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);
// Assert that the output and state are equal on CPU and CUDA
- ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
- for (int i = 0; i < output_cpu.output.dim(); i++) {
- ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+ ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+ for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+ ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
}
- for (int i = 0; i < output_cpu.output.size(0); i++) {
- for (int j = 0; j < output_cpu.output.size(1); j++) {
- for (int k = 0; k < output_cpu.output.size(2); k++) {
+ for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+ for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+ for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
ASSERT_NEAR(
- output_cpu.output[i][j][k].item<float>(),
- output_cuda.output[i][j][k].item<float>(), 1e-5);
+ std::get<0>(output_cpu)[i][j][k].item<float>(),
+ std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
}
}
}
}
+
+TEST_F(RNNTest, UsePackedSequenceAsInput) {
+ {
+ torch::manual_seed(0);
+ auto m = RNN(2, 3);
+ torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+ auto rnn_output = m->forward_with_packed_input(packed_input);
+ auto expected_output = torch::tensor(
+ {{-0.0645, -0.7274, 0.4531},
+ {-0.3970, -0.6950, 0.6009},
+ {-0.3877, -0.7310, 0.6806}});
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+ // Test passing optional argument to `RNN::forward_with_packed_input`
+ rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+ }
+ {
+ torch::manual_seed(0);
+ auto m = LSTM(2, 3);
+ torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+ auto rnn_output = m->forward_with_packed_input(packed_input);
+ auto expected_output = torch::tensor(
+ {{-0.2693, -0.1240, 0.0744},
+ {-0.3889, -0.1919, 0.1183},
+ {-0.4425, -0.2314, 0.1386}});
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+ // Test passing optional argument to `LSTM::forward_with_packed_input`
+ rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+ }
+ {
+ torch::manual_seed(0);
+ auto m = GRU(2, 3);
+ torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+ auto rnn_output = m->forward_with_packed_input(packed_input);
+ auto expected_output = torch::tensor(
+ {{-0.1134, 0.0467, 0.2336},
+ {-0.1189, 0.0502, 0.2960},
+ {-0.1138, 0.0484, 0.3110}});
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+ // Test passing optional argument to `GRU::forward_with_packed_input`
+ rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+ }
+}
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index 9107781..2bd3fd9 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -410,7 +410,7 @@
" (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
- " (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+ " (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
")");
Sequential sequential_named({
@@ -429,7 +429,7 @@
" (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
" (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
" (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
- " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+ " (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
")");
}
@@ -598,21 +598,21 @@
torch::manual_seed(0);
Sequential sequential(Identity(), RNN(2, 3));
auto x = torch::ones({2, 3, 2});
- auto rnn_output = sequential->forward<RNNOutput>(x);
+ auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
auto expected_output = torch::tensor(
- {{{0.0000, 0.0000, 0.4886},
- {0.0000, 0.0000, 0.4886},
- {0.0000, 0.0000, 0.4886}},
- {{0.0000, 0.0000, 0.3723},
- {0.0000, 0.0000, 0.3723},
- {0.0000, 0.0000, 0.3723}}});
- ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+ {{{-0.0645, -0.7274, 0.4531},
+ {-0.0645, -0.7274, 0.4531},
+ {-0.0645, -0.7274, 0.4531}},
+ {{-0.3970, -0.6950, 0.6009},
+ {-0.3970, -0.6950, 0.6009},
+ {-0.3970, -0.6950, 0.6009}}});
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
Sequential sequential(Identity(), LSTM(2, 3));
auto x = torch::ones({2, 3, 2});
- auto rnn_output = sequential->forward<RNNOutput>(x);
+ auto rnn_output = sequential->forward<std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
auto expected_output = torch::tensor(
{{{-0.2693, -0.1240, 0.0744},
{-0.2693, -0.1240, 0.0744},
@@ -620,13 +620,13 @@
{{-0.3889, -0.1919, 0.1183},
{-0.3889, -0.1919, 0.1183},
{-0.3889, -0.1919, 0.1183}}});
- ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
Sequential sequential(Identity(), GRU(2, 3));
auto x = torch::ones({2, 3, 2});
- auto rnn_output = sequential->forward<RNNOutput>(x);
+ auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
auto expected_output = torch::tensor(
{{{-0.1134, 0.0467, 0.2336},
{-0.1134, 0.0467, 0.2336},
@@ -634,7 +634,7 @@
{{-0.1189, 0.0502, 0.2960},
{-0.1189, 0.0502, 0.2960},
{-0.1189, 0.0502, 0.2960}}});
- ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+ ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
}
{
torch::manual_seed(0);
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index e9940d8..0853b96 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -81,9 +81,9 @@
torch::nn::LayerNorm|Yes|No
torch::nn::LocalResponseNorm|Yes|No
torch::nn::CrossMapLRN2d|Yes|No
-torch::nn::RNN|No|No
-torch::nn::LSTM|No|No
-torch::nn::GRU|No|No
+torch::nn::RNN|Yes|No
+torch::nn::LSTM|Yes|No
+torch::nn::GRU|Yes|No
torch::nn::RNNCell|Yes|No
torch::nn::LSTMCell|Yes|No
torch::nn::GRUCell|Yes|No
diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h
index fdd1c3c..2f28530 100644
--- a/torch/csrc/api/include/torch/enum.h
+++ b/torch/csrc/api/include/torch/enum.h
@@ -122,6 +122,10 @@
TORCH_ENUM_DECLARE(Zeros)
TORCH_ENUM_DECLARE(Border)
TORCH_ENUM_DECLARE(Reflection)
+TORCH_ENUM_DECLARE(RNN_TANH)
+TORCH_ENUM_DECLARE(RNN_RELU)
+TORCH_ENUM_DECLARE(LSTM)
+TORCH_ENUM_DECLARE(GRU)
namespace torch {
namespace enumtype {
@@ -157,6 +161,10 @@
TORCH_ENUM_PRETTY_PRINT(Zeros)
TORCH_ENUM_PRETTY_PRINT(Border)
TORCH_ENUM_PRETTY_PRINT(Reflection)
+ TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
+ TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
+ TORCH_ENUM_PRETTY_PRINT(LSTM)
+ TORCH_ENUM_PRETTY_PRINT(GRU)
};
template <typename V>
diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h
index d065faf..d324453 100644
--- a/torch/csrc/api/include/torch/nn/modules/rnn.h
+++ b/torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -5,6 +5,7 @@
#include <torch/nn/modules/common.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/pimpl.h>
+#include <torch/nn/utils/rnn.h>
#include <torch/types.h>
#include <ATen/ATen.h>
@@ -15,36 +16,23 @@
#include <memory>
#include <vector>
+using namespace torch::nn::utils::rnn;
+
namespace torch {
namespace nn {
-/// The output of a single invocation of an RNN module's `forward()` method.
-struct TORCH_API RNNOutput {
- /// The result of applying the specific RNN algorithm
- /// to the input tensor and input state.
- Tensor output;
- /// The new, updated state that can be fed into the RNN
- /// in the next forward step.
- Tensor state;
-};
-
namespace detail {
/// Base class for all RNN implementations (intended for code sharing).
template <typename Derived>
class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
public:
- /// These must line up with the CUDNN mode codes:
- /// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
- enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
-
- explicit RNNImplBase(
- const RNNOptionsBase& options_,
- optional<CuDNNMode> cudnn_mode = nullopt,
- int64_t number_of_gates = 1);
+ explicit RNNImplBase(const RNNOptionsBase& options_);
/// Initializes the parameters of the RNN module.
void reset() override;
+ void reset_parameters();
+
/// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
/// original operation.
void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
@@ -65,52 +53,32 @@
/// called once upon construction, inside `reset()`.
void flatten_parameters();
- /// The RNN's options.
- RNNOptionsBase options;
+ std::vector<Tensor> all_weights() const;
- /// The weights for `input x hidden` gates.
- std::vector<Tensor> w_ih;
- /// The weights for `hidden x hidden` gates.
- std::vector<Tensor> w_hh;
- /// The biases for `input x hidden` gates.
- std::vector<Tensor> b_ih;
- /// The biases for `hidden x hidden` gates.
- std::vector<Tensor> b_hh;
+ /// The RNN's options.
+ RNNOptionsBase options_base;
protected:
- /// The function signature of `rnn_relu`, `rnn_tanh` and `gru`.
- using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
- /*input=*/const Tensor&,
- /*state=*/const Tensor&,
- /*params=*/TensorList,
- /*has_biases=*/bool,
- /*layers=*/int64_t,
- /*dropout=*/double,
- /*train=*/bool,
- /*bidirectional=*/bool,
- /*batch_first=*/bool);
+ // Resets flat_weights_
+ // Note: be v. careful before removing this, as 3rd party device types
+ // likely rely on this behavior to properly .to() modules like LSTM.
+ void reset_flat_weights();
- /// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
- /// RNN function as first argument.
- RNNOutput generic_forward(
- std::function<RNNFunctionSignature> function,
- const Tensor& input,
- Tensor state);
+ void check_input(const Tensor& input, const Tensor& batch_sizes) const;
- /// Returns a flat vector of all weights, with layer weights following each
- /// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
- std::vector<Tensor> flat_weights() const;
+ std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) const;
- /// Very simple check if any of the parameters (weights, biases) are the same.
- bool any_parameters_alias() const;
+ void check_hidden_size(
+ const Tensor& hx,
+ std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
+ std::string msg = "Expected hidden size {1}, got {2}") const;
- /// The number of gate weights/biases required by the RNN subclass.
- int64_t number_of_gates_;
+ void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const;
- /// The cuDNN RNN mode, if this RNN subclass has any.
- optional<CuDNNMode> cudnn_mode_;
+ Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;
- /// The cached result of the latest `flat_weights()` call.
+ std::vector<std::string> flat_weights_names_;
+ std::vector<std::vector<std::string>> all_weights_;
std::vector<Tensor> flat_weights_;
};
} // namespace detail
@@ -118,84 +86,139 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer Elman RNN module with Tanh or ReLU activation.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::RNNOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
public:
RNNImpl(int64_t input_size, int64_t hidden_size)
: RNNImpl(RNNOptions(input_size, hidden_size)) {}
explicit RNNImpl(const RNNOptions& options_);
- /// Pretty prints the `RNN` module into the given `stream`.
- void pretty_print(std::ostream& stream) const override;
-
- /// Applies the `RNN` module to an input sequence and input state.
- /// The `input` should follow a `(sequence, batch, features)` layout unless
- /// `batch_first` is true, in which case the layout should be `(batch,
- /// sequence, features)`.
- RNNOutput forward(const Tensor& input, Tensor state = {});
+ std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
protected:
FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+
public:
+ std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
RNNOptions options;
+
+ protected:
+ std::tuple<Tensor, Tensor> forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ Tensor hx);
};
/// A `ModuleHolder` subclass for `RNNImpl`.
-/// See the documentation for `RNNImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `RNNImpl` class to learn what methods it
+/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
TORCH_MODULE(RNN);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer long-short-term-memory (LSTM) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LSTMOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
public:
LSTMImpl(int64_t input_size, int64_t hidden_size)
: LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
explicit LSTMImpl(const LSTMOptions& options_);
- /// Applies the `LSTM` module to an input sequence and input state.
- /// The `input` should follow a `(sequence, batch, features)` layout unless
- /// `batch_first` is true, in which case the layout should be `(batch,
- /// sequence, features)`.
- RNNOutput forward(const Tensor& input, Tensor state = {});
+ std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
+ const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
protected:
- FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+ FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
+
+ public:
+ std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> forward_with_packed_input(
+ const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+
+ LSTMOptions options;
+
+ protected:
+ void check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const;
+
+ std::tuple<Tensor, Tensor> permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const;
+
+ std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
};
/// A `ModuleHolder` subclass for `LSTMImpl`.
/// See the documentation for `LSTMImpl` class to learn what methods it
-/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
/// module storage semantics.
TORCH_MODULE(LSTM);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A multi-layer gated recurrent unit (GRU) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::GRUOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
public:
GRUImpl(int64_t input_size, int64_t hidden_size)
: GRUImpl(GRUOptions(input_size, hidden_size)) {}
explicit GRUImpl(const GRUOptions& options_);
- /// Applies the `GRU` module to an input sequence and input state.
- /// The `input` should follow a `(sequence, batch, features)` layout unless
- /// `batch_first` is true, in which case the layout should be `(batch,
- /// sequence, features)`.
- RNNOutput forward(const Tensor& input, Tensor state = {});
+ std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
protected:
- FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+ FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})
+
+ public:
+ std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
+ GRUOptions options;
+
+ protected:
+ std::tuple<Tensor, Tensor> forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ Tensor hx);
};
/// A `ModuleHolder` subclass for `GRUImpl`.
-/// See the documentation for `GRUImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `GRUImpl` class to learn what methods it
+/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
TORCH_MODULE(GRU);
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h
index f500a90..ae37693 100644
--- a/torch/csrc/api/include/torch/nn/options/rnn.h
+++ b/torch/csrc/api/include/torch/nn/options/rnn.h
@@ -10,65 +10,137 @@
namespace detail {
-/// Common options for LSTM and GRU modules.
+/// Common options for RNN, LSTM and GRU modules.
struct TORCH_API RNNOptionsBase {
- RNNOptionsBase(int64_t input_size, int64_t hidden_size);
- virtual ~RNNOptionsBase() = default;
+ typedef c10::variant<
+ enumtype::kLSTM,
+ enumtype::kGRU,
+ enumtype::kRNN_TANH,
+ enumtype::kRNN_RELU> rnn_options_base_mode_t;
+
+ RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size);
+
+ TORCH_ARG(rnn_options_base_mode_t, mode);
/// The number of features of a single sample in the input sequence `x`.
TORCH_ARG(int64_t, input_size);
/// The number of features in the hidden state `h`.
TORCH_ARG(int64_t, hidden_size);
/// The number of recurrent layers (cells) to use.
- TORCH_ARG(int64_t, layers) = 1;
+ TORCH_ARG(int64_t, num_layers) = 1;
/// Whether a bias term should be added to all linear operations.
- TORCH_ARG(bool, with_bias) = true;
+ TORCH_ARG(bool, bias) = true;
+ /// If true, the input sequence should be provided as `(batch, sequence,
+ /// features)`. If false (default), the expected layout is `(sequence, batch,
+ /// features)`.
+ TORCH_ARG(bool, batch_first) = false;
/// If non-zero, adds dropout with the given probability to the output of each
/// RNN layer, except the final layer.
TORCH_ARG(double, dropout) = 0.0;
/// Whether to make the RNN bidirectional.
TORCH_ARG(bool, bidirectional) = false;
- /// If true, the input sequence should be provided as `(batch, sequence,
- /// features)`. If false (default), the expected layout is `(sequence, batch,
- /// features)`.
- TORCH_ARG(bool, batch_first) = false;
};
} // namespace detail
-enum class RNNActivation : uint32_t {ReLU, Tanh};
-
-/// Options for RNN modules.
+/// Options for the `RNN` module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
struct TORCH_API RNNOptions {
+ typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
+
RNNOptions(int64_t input_size, int64_t hidden_size);
- /// Sets the activation after linear operations to `tanh`.
- RNNOptions& tanh();
- /// Sets the activation after linear operations to `relu`.
- RNNOptions& relu();
-
- /// The number of features of a single sample in the input sequence `x`.
+ /// The number of expected features in the input `x`
TORCH_ARG(int64_t, input_size);
- /// The number of features in the hidden state `h`.
+ /// The number of features in the hidden state `h`
TORCH_ARG(int64_t, hidden_size);
- /// The number of recurrent layers (cells) to use.
- TORCH_ARG(int64_t, layers) = 1;
- /// Whether a bias term should be added to all linear operations.
- TORCH_ARG(bool, with_bias) = true;
- /// If non-zero, adds dropout with the given probability to the output of each
- /// RNN layer, except the final layer.
- TORCH_ARG(double, dropout) = 0.0;
- /// Whether to make the RNN bidirectional.
- TORCH_ARG(bool, bidirectional) = false;
- /// If true, the input sequence should be provided as `(batch, sequence,
- /// features)`. If false (default), the expected layout is `(sequence, batch,
- /// features)`.
+ /// Number of recurrent layers. E.g., setting ``num_layers=2``
+ /// would mean stacking two RNNs together to form a `stacked RNN`,
+ /// with the second RNN taking in outputs of the first RNN and
+ /// computing the final results. Default: 1
+ TORCH_ARG(int64_t, num_layers) = 1;
+ /// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh``
+ TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
+ /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ /// Default: ``true``
+ TORCH_ARG(bool, bias) = true;
+ /// If ``true``, then the input and output tensors are provided
+ /// as `(batch, seq, feature)`. Default: ``false``
TORCH_ARG(bool, batch_first) = false;
- /// The activation to use after linear operations.
- TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
+ /// If non-zero, introduces a `Dropout` layer on the outputs of each
+ /// RNN layer except the last layer, with dropout probability equal to
+ /// `dropout`. Default: 0
+ TORCH_ARG(double, dropout) = 0.0;
+ /// If ``true``, becomes a bidirectional RNN. Default: ``false``
+ TORCH_ARG(bool, bidirectional) = false;
};
-using LSTMOptions = detail::RNNOptionsBase;
-using GRUOptions = detail::RNNOptionsBase;
+/// Options for the `LSTM` module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API LSTMOptions {
+ LSTMOptions(int64_t input_size, int64_t hidden_size);
+
+ /// The number of expected features in the input `x`
+ TORCH_ARG(int64_t, input_size);
+ /// The number of features in the hidden state `h`
+ TORCH_ARG(int64_t, hidden_size);
+ /// Number of recurrent layers. E.g., setting ``num_layers=2``
+ /// would mean stacking two LSTMs together to form a `stacked LSTM`,
+ /// with the second LSTM taking in outputs of the first LSTM and
+ /// computing the final results. Default: 1
+ TORCH_ARG(int64_t, num_layers) = 1;
+ /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ /// Default: ``true``
+ TORCH_ARG(bool, bias) = true;
+ /// If ``true``, then the input and output tensors are provided
+ /// as (batch, seq, feature). Default: ``false``
+ TORCH_ARG(bool, batch_first) = false;
+ /// If non-zero, introduces a `Dropout` layer on the outputs of each
+ /// LSTM layer except the last layer, with dropout probability equal to
+ /// `dropout`. Default: 0
+ TORCH_ARG(double, dropout) = 0.0;
+ /// If ``true``, becomes a bidirectional LSTM. Default: ``false``
+ TORCH_ARG(bool, bidirectional) = false;
+};
+
+/// Options for the `GRU` module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API GRUOptions {
+ GRUOptions(int64_t input_size, int64_t hidden_size);
+
+ /// The number of expected features in the input `x`
+ TORCH_ARG(int64_t, input_size);
+ /// The number of features in the hidden state `h`
+ TORCH_ARG(int64_t, hidden_size);
+ /// Number of recurrent layers. E.g., setting ``num_layers=2``
+ /// would mean stacking two GRUs together to form a `stacked GRU`,
+ /// with the second GRU taking in outputs of the first GRU and
+ /// computing the final results. Default: 1
+ TORCH_ARG(int64_t, num_layers) = 1;
+ /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+ /// Default: ``true``
+ TORCH_ARG(bool, bias) = true;
+ /// If ``true``, then the input and output tensors are provided
+ /// as (batch, seq, feature). Default: ``false``
+ TORCH_ARG(bool, batch_first) = false;
+ /// If non-zero, introduces a `Dropout` layer on the outputs of each
+ /// GRU layer except the last layer, with dropout probability equal to
+ /// `dropout`. Default: 0
+ TORCH_ARG(double, dropout) = 0.0;
+ /// If ``true``, becomes a bidirectional GRU. Default: ``false``
+ TORCH_ARG(bool, bidirectional) = false;
+};
namespace detail {
diff --git a/torch/csrc/api/src/enum.cpp b/torch/csrc/api/src/enum.cpp
index de15449..28bd25b 100644
--- a/torch/csrc/api/src/enum.cpp
+++ b/torch/csrc/api/src/enum.cpp
@@ -30,3 +30,7 @@
TORCH_ENUM_DEFINE(Zeros)
TORCH_ENUM_DEFINE(Border)
TORCH_ENUM_DEFINE(Reflection)
+TORCH_ENUM_DEFINE(RNN_TANH)
+TORCH_ENUM_DEFINE(RNN_RELU)
+TORCH_ENUM_DEFINE(LSTM)
+TORCH_ENUM_DEFINE(GRU)
diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp
index 7f8aa39..0e54fda 100644
--- a/torch/csrc/api/src/nn/modules/rnn.cpp
+++ b/torch/csrc/api/src/nn/modules/rnn.cpp
@@ -1,6 +1,5 @@
#include <torch/nn/modules/rnn.h>
-#include <torch/nn/modules/dropout.h>
#include <torch/nn/init.h>
#include <torch/types.h>
#include <torch/utils.h>
@@ -19,65 +18,193 @@
#include <utility>
#include <vector>
+using namespace torch::nn::utils::rnn;
+
namespace torch {
namespace nn {
+
+/// These must line up with the CUDNN mode codes:
+/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
+enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
+
+CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
+ if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
+ return CuDNNMode::RNN_RELU;
+ } else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
+ return CuDNNMode::RNN_TANH;
+ } else if (c10::get_if<enumtype::kLSTM>(&mode)) {
+ return CuDNNMode::LSTM;
+ } else if (c10::get_if<enumtype::kGRU>(&mode)) {
+ return CuDNNMode::GRU;
+ } else {
+ TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
+ }
+}
+
+Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_t dim = 1) {
+ return tensor.index_select(dim, permutation);
+}
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
namespace detail {
template <typename Derived>
-RNNImplBase<Derived>::RNNImplBase(
- const RNNOptionsBase& options_,
- optional<CuDNNMode> cudnn_mode,
- int64_t number_of_gates)
- : options(options_),
- number_of_gates_(number_of_gates),
- cudnn_mode_(std::move(cudnn_mode)) {
+RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
+ : options_base(options_) {
reset();
}
template <typename Derived>
void RNNImplBase<Derived>::reset() {
- const auto num_directions = options.bidirectional() ? 2 : 1;
+ const int64_t num_directions = options_base.bidirectional() ? 2 : 1;
- w_ih.resize(options.layers() * num_directions);
- w_hh.resize(options.layers() * num_directions);
- b_ih.resize(options.layers() * num_directions);
- b_hh.resize(options.layers() * num_directions);
+ TORCH_CHECK(
+ 0 <= options_base.dropout() && options_base.dropout() <= 1,
+ "dropout should be a number in range [0, 1] ",
+ "representing the probability of an element being ",
+ "zeroed");
- const int64_t gate_size = options.hidden_size() * number_of_gates_;
+ if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
+ TORCH_WARN(
+ "dropout option adds dropout after all but last ",
+ "recurrent layer, so non-zero dropout expects ",
+ "num_layers greater than 1, but got dropout=", options_base.dropout(), " and ",
+ "num_layers=", options_base.num_layers());
+ }
- for (int64_t layer = 0; layer < options.layers(); ++layer) {
- for (auto direction = 0; direction < num_directions; direction++) {
- const auto layer_input_size = layer == 0 ? options.input_size() :
- options.hidden_size() * num_directions;
- const auto suffix = direction == 1 ? "_reverse" : "";
- const auto layer_idx = (layer * num_directions) + direction;
- w_ih[layer_idx] = this->register_parameter(
- "weight_ih_l" + std::to_string(layer) + suffix,
- torch::empty({gate_size, layer_input_size}));
- w_hh[layer_idx] = this->register_parameter(
- "weight_hh_l" + std::to_string(layer) + suffix,
- torch::empty({gate_size, options.hidden_size()}));
+ int64_t gate_size = 0;
+ if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
+ gate_size = 4 * options_base.hidden_size();
+ } else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
+ gate_size = 3 * options_base.hidden_size();
+ } else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+ gate_size = options_base.hidden_size();
+ } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+ gate_size = options_base.hidden_size();
+ } else {
+ TORCH_CHECK(false, "Unrecognized RNN mode: " + torch::enumtype::get_enum_name(options_base.mode()));
+ }
- if (options.with_bias()) {
- b_ih[layer_idx] = this->register_parameter(
- "bias_ih_l" + std::to_string(layer) + suffix,
- torch::empty({gate_size}));
- b_hh[layer_idx] = this->register_parameter(
- "bias_hh_l" + std::to_string(layer) + suffix,
- torch::empty({gate_size}));
+ flat_weights_names_ = {};
+ all_weights_ = {};
+
+ for (int64_t layer = 0; layer < options_base.num_layers(); layer++) {
+ for (int64_t direction = 0; direction < num_directions; direction++) {
+ int64_t layer_input_size = layer == 0 ? options_base.input_size() : options_base.hidden_size() * num_directions;
+
+ auto w_ih = torch::empty({gate_size, layer_input_size});
+ auto w_hh = torch::empty({gate_size, options_base.hidden_size()});
+ auto b_ih = torch::empty({gate_size});
+ // Second bias vector included for CuDNN compatibility. Only one
+ // bias vector is needed in standard definition.
+ auto b_hh = torch::empty({gate_size});
+ std::vector<Tensor> layer_params = {w_ih, w_hh, b_ih, b_hh};
+
+ std::string suffix = direction == 1 ? "_reverse" : "";
+ std::vector<std::string> param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
+ if (options_base.bias()) {
+ param_names.emplace_back("bias_ih_l{layer}{suffix}");
+ param_names.emplace_back("bias_hh_l{layer}{suffix}");
}
+ for (size_t i = 0; i < param_names.size(); i++) { // NOLINT(modernize-loop-convert)
+ std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer));
+ x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
+ param_names[i] = x;
+ }
+
+ for (size_t i = 0; i < param_names.size(); i++) {
+ auto name = param_names[i];
+ auto param = layer_params[i];
+ this->register_parameter(name, param);
+ }
+ flat_weights_names_.insert(flat_weights_names_.end(), param_names.begin(), param_names.end());
+ all_weights_.emplace_back(param_names);
}
}
+ flat_weights_ = {};
+ for (const auto& wn : flat_weights_names_) {
+ auto named_parameters = this->named_parameters(/*recurse=*/false);
+ if (named_parameters.contains(wn)) {
+ flat_weights_.emplace_back(named_parameters[wn]);
+ } else {
+ flat_weights_.emplace_back(Tensor());
+ }
+ }
+
+ this->flatten_parameters();
+ this->reset_parameters();
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::flatten_parameters() {
+ // Resets parameter data pointer so that they can use faster code paths.
+ //
+ // Right now, this works only if the module is on the GPU and cuDNN is enabled.
+ // Otherwise, it's a no-op.
+
+ // Short-circuits if flat_weights_ is only partially instantiated
+ if (flat_weights_.size() != flat_weights_names_.size()) {
+ return;
+ }
+
+ // Short-circuits if any tensor in self.flat_weights_ is not acceptable to cuDNN
+ // or the tensors in flat_weights_ are of different dtypes
+
+ auto first_fw = flat_weights_[0];
+ auto dtype = first_fw.dtype();
+ for (const auto& fw : flat_weights_) {
+ if (!(fw.dtype() == dtype) ||
+ !fw.is_cuda() ||
+ !torch::cudnn_is_acceptable(fw)) {
+ return;
+ }
+ }
+
+ // If any parameters alias, we fall back to the slower, copying code path. This is
+ // a sufficient check, because overlapping parameter buffers that don't completely
+ // alias would break the assumptions of the uniqueness check in
+ // Module::named_parameters().
+ std::unordered_set<void*> unique_data_ptrs;
+ for (const auto& p : flat_weights_) {
+ unique_data_ptrs.emplace(p.data_ptr());
+ }
+ if (unique_data_ptrs.size() != flat_weights_.size()) {
+ return;
+ }
+
{
- NoGradGuard no_grad;
- const auto stdv = 1.0 / std::sqrt(options.hidden_size());
- for (auto& p : this->parameters()) {
- p.uniform_(-stdv, stdv);
+ torch::DeviceGuard device_guard(first_fw.device());
+
+ // Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
+ // an inplace operation on self.flat_weights_
+ {
+ torch::NoGradGuard no_grad;
+ if (torch::_use_cudnn_rnn_flatten_weight()) {
+ torch::_cudnn_rnn_flatten_weight(
+ flat_weights_,
+ options_base.bias() ? 4 : 2,
+ options_base.input_size(),
+ static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
+ options_base.hidden_size(),
+ options_base.num_layers(),
+ options_base.batch_first(),
+ options_base.bidirectional());
+ }
}
}
+}
- flatten_parameters();
+template <typename Derived>
+void RNNImplBase<Derived>::reset_flat_weights() {
+ flat_weights_ = {};
+ for (const auto& wn : flat_weights_names_) {
+ auto named_parameters = this->named_parameters(/*recurse=*/false);
+ if (named_parameters.contains(wn)) {
+ flat_weights_.emplace_back(named_parameters[wn]);
+ } else {
+ flat_weights_.emplace_back(Tensor());
+ }
+ }
}
template <typename Derived>
@@ -86,126 +213,113 @@
torch::Dtype dtype,
bool non_blocking) {
nn::Module::to(device, dtype, non_blocking);
+ reset_flat_weights();
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
nn::Module::to(dtype, non_blocking);
+ reset_flat_weights();
flatten_parameters();
}
template <typename Derived>
void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
nn::Module::to(device, non_blocking);
- const auto num_directions = options.bidirectional() ? 2 : 1;
- for (int64_t layer = 0; layer < options.layers(); layer++) {
- for (auto direction = 0; direction < num_directions; direction++) {
- const auto layer_idx = (layer * num_directions) + direction;
- w_ih[layer_idx] = w_ih[layer_idx].to(device, non_blocking);
- w_hh[layer_idx] = w_hh[layer_idx].to(device, non_blocking);
- if (options.with_bias()) {
- b_ih[layer_idx] = b_ih[layer_idx].to(device, non_blocking);
- b_hh[layer_idx] = b_hh[layer_idx].to(device, non_blocking);
- }
- }
- }
+ reset_flat_weights();
flatten_parameters();
}
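
For illustration, a minimal usage sketch of the behavior above (module sizes are illustrative; not part of the diff): calling to() re-registers the flat weights and re-runs flatten_parameters(), which stays a no-op unless the weights are CUDA tensors acceptable to cuDNN.

    #include <torch/torch.h>

    int main() {
      auto lstm = torch::nn::LSTM(torch::nn::LSTMOptions(10, 20).num_layers(2));
      // On CPU this is a no-op; on CUDA the weights are re-flattened for cuDNN.
      if (torch::cuda::is_available()) {
        lstm->to(torch::kCUDA);
      }
      return 0;
    }
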
template <typename Derived>
+void RNNImplBase<Derived>::reset_parameters() {
+ const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
+ for (auto& weight : this->parameters()) {
+ init::uniform_(weight, -stdv, stdv);
+ }
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_input(const Tensor& input, const Tensor& batch_sizes) const {
+ int64_t expected_input_dim = batch_sizes.defined() ? 2 : 3;
+ TORCH_CHECK(
+ input.dim() == expected_input_dim,
+ "input must have ", expected_input_dim, " dimensions, got ", input.dim());
+ TORCH_CHECK(
+ options_base.input_size() == input.size(-1),
+ "input.size(-1) must be equal to input_size. Expected ", options_base.input_size(), ", got ", input.size(-1));
+}
+
+template <typename Derived>
+std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::get_expected_hidden_size(
+ const Tensor& input, const Tensor& batch_sizes) const {
+ int64_t mini_batch = 0;
+ if (batch_sizes.defined()) {
+ mini_batch = batch_sizes[0].item<int64_t>();
+ } else {
+ mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
+ }
+ int64_t num_directions = options_base.bidirectional() ? 2 : 1;
+ return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size());
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_hidden_size(
+ const Tensor& hx,
+ std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
+ std::string msg) const {
+ auto expected_hidden_size_vec = std::vector<int64_t>({
+ std::get<0>(expected_hidden_size),
+ std::get<1>(expected_hidden_size),
+ std::get<2>(expected_hidden_size),
+ });
+ if (hx.sizes() != expected_hidden_size_vec) {
+ msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
+ msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
+ TORCH_CHECK(false, msg);
+ }
+}
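
For illustration, a standalone sketch of the placeholder substitution used above; the message template here is an assumed example (the real default lives in the header), only the {1}/{2} regex mechanics mirror the code:

    #include <torch/torch.h>
    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
      auto expected_hx = torch::zeros({2, 3, 20});
      auto hx = torch::zeros({1, 3, 20});
      std::string msg = "Expected hidden size {1}, got {2}";  // assumed template
      msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hx.sizes()));
      msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
      std::cout << msg << std::endl;  // e.g. "Expected hidden size [2, 3, 20], got [1, 3, 20]"
      return 0;
    }
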
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const {
+ this->check_input(input, batch_sizes);
+ auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
+
+ this->check_hidden_size(hidden, expected_hidden_size);
+}
+
+template <typename Derived>
+Tensor RNNImplBase<Derived>::permute_hidden(Tensor hx, const Tensor& permutation) const {
+ if (!permutation.defined()) {
+ return hx;
+ }
+ return apply_permutation(hx, permutation);
+}
+
+template <typename Derived>
void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
const std::string name = this->name();
const std::string name_without_impl = name.substr(0, name.size() - 4);
- stream << name_without_impl << "(input_size=" << options.input_size()
- << ", hidden_size=" << options.hidden_size()
- << ", layers=" << options.layers() << ", dropout=" << options.dropout()
+ stream << std::boolalpha << name_without_impl << "(input_size=" << options_base.input_size()
+ << ", hidden_size=" << options_base.hidden_size()
+ << ", num_layers=" << options_base.num_layers()
+ << ", bias=" << options_base.bias()
+ << ", batch_first=" << options_base.batch_first()
+ << ", dropout=" << options_base.dropout()
+ << ", bidirectional=" << options_base.bidirectional()
<< ")";
}
template <typename Derived>
-void RNNImplBase<Derived>::flatten_parameters() {
- // Cache the flattened weight and bias vector.
- flat_weights_ = flat_weights();
-
- if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
- return;
- }
-
- NoGradGuard no_grad;
- if (torch::_use_cudnn_rnn_flatten_weight()) {
- torch::_cudnn_rnn_flatten_weight(
- flat_weights_,
- /*weight_stride0=*/options.with_bias() ? 4 : 2,
- options.input_size(),
- static_cast<int64_t>(*cudnn_mode_),
- options.hidden_size(),
- options.layers(),
- /*batch_first=*/options.batch_first(),
- /*bidirectional=*/options.bidirectional());
- }
-}
-
-template <typename Derived>
-RNNOutput RNNImplBase<Derived>::generic_forward(
- std::function<RNNFunctionSignature> function,
- const Tensor& input,
- Tensor state) {
- if (!state.defined()) {
- // #layers, batch size, state size
- const auto batch_size = input.size(options.batch_first() ? 0 : 1);
- const auto num_directions = options.bidirectional() ? 2 : 1;
- state = torch::zeros(
- {options.layers() * num_directions, batch_size, options.hidden_size()},
- input.options());
- }
- Tensor output, new_state;
- std::tie(output, new_state) = function(
- input,
- std::move(state),
- flat_weights_,
- options.with_bias(),
- options.layers(),
- options.dropout(),
- this->is_training(),
- options.bidirectional(),
- options.batch_first());
- return {output, new_state};
-}
-
-template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
- // Organize all weights in a flat vector in the order
- // (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
- std::vector<Tensor> flat;
- const auto num_directions = options.bidirectional() ? 2 : 1;
- for (int64_t layer = 0; layer < options.layers(); layer++) {
- for (auto direction = 0; direction < num_directions; direction++) {
- const auto layer_idx = (layer * num_directions) + direction;
- flat.push_back(w_ih[layer_idx]);
- flat.push_back(w_hh[layer_idx]);
- if (options.with_bias()) {
- flat.push_back(b_ih[layer_idx]);
- flat.push_back(b_hh[layer_idx]);
- }
+std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
+ std::vector<Tensor> result = {};
+ auto named_parameters = this->named_parameters(/*recurse=*/false);
+ for (const auto& weights : all_weights_) {
+ for (const auto& weight : weights) {
+ result.emplace_back(named_parameters[weight]);
}
}
- return flat;
-}
-
-template <typename Derived>
-bool RNNImplBase<Derived>::any_parameters_alias() const {
- // If any parameters alias, we fall back to the slower, copying code path.
- // This is a sufficient check, because overlapping parameter buffers that
- // don't completely alias would break the assumptions of the uniqueness check
- // in Module.named_parameters().
- std::unordered_set<void*> unique_data_ptrs;
- auto params = this->parameters();
- unique_data_ptrs.reserve(params.size());
- for (const auto& p : params) {
- unique_data_ptrs.emplace(p.data_ptr());
- }
- return unique_data_ptrs.size() != params.size();
+ return result;
}
template class RNNImplBase<LSTMImpl>;
@@ -215,91 +329,275 @@
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-RNNImpl::RNNImpl(const RNNOptions& options_)
- : detail::RNNImplBase<RNNImpl>(
- detail::RNNOptionsBase(options_.input_size(), options_.hidden_size())
- .layers(options_.layers())
- .with_bias(options_.with_bias())
- .dropout(options_.dropout())
- .bidirectional(options_.bidirectional())
- .batch_first(options_.batch_first()),
- static_cast<CuDNNMode>(options_.activation())),
- options(options_) {}
-
-void RNNImpl::pretty_print(std::ostream& stream) const {
- stream << "torch::nn::RNN(input_size=" << options.input_size()
- << ", hidden_size=" << options.hidden_size()
- << ", layers=" << options.layers() << ", dropout=" << options.dropout()
- << ", activation="
- << (options.activation() == RNNActivation::Tanh ? "tanh" : "relu")
- << ")";
+detail::RNNOptionsBase::rnn_options_base_mode_t compute_rnn_options_base_mode(
+ RNNOptions::nonlinearity_t nonlinearity) {
+ if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
+ return torch::kRNN_TANH;
+ } else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
+ return torch::kRNN_RELU;
+ } else {
+ TORCH_CHECK(false, "Unknown nonlinearity ", torch::enumtype::get_enum_name(nonlinearity));
+ }
}
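
For illustration, a minimal sketch (sizes are illustrative) of how the user-facing nonlinearity option feeds this mapping:

    #include <torch/torch.h>

    int main() {
      // nonlinearity(torch::kTanh) is the default and maps to kRNN_TANH;
      // torch::kReLU maps to kRNN_RELU.
      auto relu_rnn = torch::nn::RNN(
          torch::nn::RNNOptions(/*input_size=*/8, /*hidden_size=*/16)
              .nonlinearity(torch::kReLU));
      return 0;
    }
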
-RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
- switch (options.activation()) {
- case RNNActivation::ReLU:
- return generic_forward(
- static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
- input,
- std::move(state));
- case RNNActivation::Tanh:
- return generic_forward(
- static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
- input,
- std::move(state));
- default:
- AT_ERROR("Unhandled RNN activation function!");
+RNNImpl::RNNImpl(const RNNOptions& options_)
+ : detail::RNNImplBase<RNNImpl>(
+ detail::RNNOptionsBase(
+ compute_rnn_options_base_mode(options_.nonlinearity()),
+ options_.input_size(),
+ options_.hidden_size())
+ .num_layers(options_.num_layers())
+ .bias(options_.bias())
+ .batch_first(options_.batch_first())
+ .dropout(options_.dropout())
+ .bidirectional(options_.bidirectional())),
+ options(options_) {}
+
+std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ Tensor hx) {
+ if (!hx.defined()) {
+ int64_t num_directions = options_base.bidirectional() ? 2 : 1;
+ hx = torch::zeros({options_base.num_layers() * num_directions,
+ max_batch_size, options_base.hidden_size()},
+ torch::dtype(input.dtype()).device(input.device()));
+ } else {
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes they are passing in.
+ hx = this->permute_hidden(hx, sorted_indices);
+ }
+
+ this->check_forward_args(input, hx, batch_sizes);
+
+ std::tuple<Tensor, Tensor> result;
+ if (!batch_sizes.defined()) {
+ if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+ result = torch::rnn_tanh(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
+ options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
+ } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+ result = torch::rnn_relu(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
+ options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
+ } else {
+ TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
+ }
+ } else {
+ if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+ result = torch::rnn_tanh(input, batch_sizes, hx, flat_weights_, options_base.bias(),
+ options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
+ } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+ result = torch::rnn_relu(input, batch_sizes, hx, flat_weights_, options_base.bias(),
+ options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
+ } else {
+ TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
+ }
}
+ auto output = std::get<0>(result);
+ auto hidden = std::get<1>(result);
+
+ return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
+ auto batch_sizes = torch::Tensor();
+ auto max_batch_size = options_base.batch_first() ? input.size(0) : input.size(1);
+ auto sorted_indices = torch::Tensor();
+ auto unsorted_indices = torch::Tensor();
+
+ Tensor output, hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+ return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
+ const auto& input = packed_input.data();
+ const auto& batch_sizes = packed_input.batch_sizes();
+ const auto& sorted_indices = packed_input.sorted_indices();
+ const auto& unsorted_indices = packed_input.unsorted_indices();
+ auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+ Tensor output, hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+ auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+ return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
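
For illustration, a minimal end-to-end sketch of the revised RNN forward API (sizes are illustrative; when the hidden state is omitted it defaults to zeros, as in forward_helper above):

    #include <torch/torch.h>
    #include <tuple>

    int main() {
      auto rnn = torch::nn::RNN(
          torch::nn::RNNOptions(/*input_size=*/10, /*hidden_size=*/20).num_layers(2));
      auto input = torch::randn({5, 3, 10});  // (seq_len, batch, input_size)

      torch::Tensor output, h_n;
      std::tie(output, h_n) = rnn->forward(input);
      // output: (5, 3, 20); h_n: (num_layers * num_directions, 3, 20)
      return 0;
    }
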
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
LSTMImpl::LSTMImpl(const LSTMOptions& options_)
: detail::RNNImplBase<LSTMImpl>(
- options_,
- CuDNNMode::LSTM,
- /*number_of_gates=*/4) {}
+ detail::RNNOptionsBase(
+ torch::kLSTM,
+ options_.input_size(),
+ options_.hidden_size())
+ .num_layers(options_.num_layers())
+ .bias(options_.bias())
+ .batch_first(options_.batch_first())
+ .dropout(options_.dropout())
+ .bidirectional(options_.bidirectional())),
+ options(options_) {}
-RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
- // It would be trickier to adapt the `generic_forward` for the LSTM because
- // its output has a different dimensionality (3-tuple vs. 2-tuple), while we
- // always return one state variable (stacking the hidden/cell state into one),
- // which also makes the state variables going into the `generic_forward`, and
- // the way we default-initialize the state when it is not passed, slightly
- // different. So we just re-implement it specifically for the LSTM here.
- if (!state.defined()) {
- // 2 for hidden state and cell state, then #layers, batch size, state size
- const auto batch_size = input.size(options.batch_first() ? 0 : 1);
- const auto num_directions = options.bidirectional() ? 2 : 1;
- state = torch::zeros(
- {2, options.layers() * num_directions, batch_size, options.hidden_size()},
- input.options());
+void LSTMImpl::check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const {
+ this->check_input(input, batch_sizes);
+ auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
+
+ this->check_hidden_size(std::get<0>(hidden), expected_hidden_size,
+ "Expected hidden[0] size {1}, got {2}");
+ this->check_hidden_size(std::get<1>(hidden), expected_hidden_size,
+ "Expected hidden[1] size {1}, got {2}");
+}
+
+std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const {
+ if (!permutation.defined()) {
+ return hx;
}
- Tensor output, hidden_state, cell_state;
- std::tie(output, hidden_state, cell_state) = torch::lstm(
- input,
- {state[0], state[1]},
- flat_weights_,
- options.with_bias(),
- options.layers(),
- options.dropout(),
- this->is_training(),
- options.bidirectional(),
- options.batch_first());
- return {output, torch::stack({hidden_state, cell_state})};
+ return std::make_tuple(
+ apply_permutation(std::get<0>(hx), permutation),
+ apply_permutation(std::get<1>(hx), permutation)
+ );
+}
+
+std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+
+ std::tuple<Tensor, Tensor> hx;
+ if (!hx_opt.has_value()) {
+ int64_t num_directions = options.bidirectional() ? 2 : 1;
+ auto zeros = torch::zeros({options.num_layers() * num_directions,
+ max_batch_size, options.hidden_size()},
+ torch::dtype(input.dtype()).device(input.device()));
+ hx = std::make_tuple(zeros, zeros);
+ } else {
+ hx = hx_opt.value();
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes they are passing in.
+ hx = this->permute_hidden(hx, sorted_indices);
+ }
+
+ this->check_forward_args(input, hx, batch_sizes);
+ std::tuple<Tensor, Tensor, Tensor> result;
+ if (!batch_sizes.defined()) {
+ result = torch::lstm(input, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), options.num_layers(),
+ options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
+ } else {
+ result = torch::lstm(input, batch_sizes, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(),
+ options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
+ }
+ auto output = std::get<0>(result);
+ auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));
+
+ return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
+ const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+ auto batch_sizes = torch::Tensor();
+ auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
+ auto sorted_indices = torch::Tensor();
+ auto unsorted_indices = torch::Tensor();
+
+ Tensor output;
+ std::tuple<Tensor, Tensor> hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
+
+ return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::forward_with_packed_input(
+ const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+ const auto& input = packed_input.data();
+ const auto& batch_sizes = packed_input.batch_sizes();
+ const auto& sorted_indices = packed_input.sorted_indices();
+ const auto& unsorted_indices = packed_input.unsorted_indices();
+ auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+ Tensor output;
+ std::tuple<Tensor, Tensor> hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
+
+ auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+ return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
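
Similarly, a minimal sketch of the LSTM path, whose state is the nested (h, c) tuple (sizes are illustrative):

    #include <torch/torch.h>
    #include <tuple>

    int main() {
      auto lstm = torch::nn::LSTM(
          torch::nn::LSTMOptions(/*input_size=*/10, /*hidden_size=*/20).num_layers(2));
      auto input = torch::randn({5, 3, 10});  // (seq_len, batch, input_size)

      torch::Tensor output;
      std::tuple<torch::Tensor, torch::Tensor> state;
      std::tie(output, state) = lstm->forward(input);  // state defaults to zeros
      auto h_n = std::get<0>(state);  // (num_layers * num_directions, 3, 20)
      auto c_n = std::get<1>(state);  // same shape as h_n
      return 0;
    }
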
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
GRUImpl::GRUImpl(const GRUOptions& options_)
: detail::RNNImplBase<GRUImpl>(
- options_,
- CuDNNMode::GRU,
- /*number_of_gates=*/3) {}
+ detail::RNNOptionsBase(
+ torch::kGRU,
+ options_.input_size(),
+ options_.hidden_size())
+ .num_layers(options_.num_layers())
+ .bias(options_.bias())
+ .batch_first(options_.batch_first())
+ .dropout(options_.dropout())
+ .bidirectional(options_.bidirectional())),
+ options(options_) {}
-RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
- return generic_forward(
- static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
+std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
+ const Tensor& input,
+ const Tensor& batch_sizes,
+ const Tensor& sorted_indices,
+ int64_t max_batch_size,
+ Tensor hx) {
+ if (!hx.defined()) {
+ int64_t num_directions = options.bidirectional() ? 2 : 1;
+ hx = torch::zeros({options.num_layers() * num_directions,
+ max_batch_size, options.hidden_size()},
+ torch::dtype(input.dtype()).device(input.device()));
+ } else {
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes they are passing in.
+ hx = this->permute_hidden(hx, sorted_indices);
+ }
+
+ this->check_forward_args(input, hx, batch_sizes);
+ std::tuple<Tensor, Tensor> result;
+ if (!batch_sizes.defined()) {
+ result = torch::gru(input, hx, flat_weights_, options.bias(), options.num_layers(),
+ options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
+ } else {
+ result = torch::gru(input, batch_sizes, hx, flat_weights_, options.bias(),
+ options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
+ }
+ auto output = std::get<0>(result);
+ auto hidden = std::get<1>(result);
+
+ return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
+ auto batch_sizes = torch::Tensor();
+ auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
+ auto sorted_indices = torch::Tensor();
+ auto unsorted_indices = torch::Tensor();
+
+ Tensor output, hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+ return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
+ const auto& input = packed_input.data();
+ const auto& batch_sizes = packed_input.batch_sizes();
+ const auto& sorted_indices = packed_input.sorted_indices();
+ const auto& unsorted_indices = packed_input.unsorted_indices();
+ auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+ Tensor output, hidden;
+ std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+ auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+ return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
}
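
Finally, a sketch of the packed-sequence path, assuming the pack_padded_sequence helper in torch::nn::utils::rnn (which is outside this diff); sizes and lengths are illustrative:

    #include <torch/torch.h>
    #include <tuple>

    int main() {
      namespace rnn_utils = torch::nn::utils::rnn;

      auto gru = torch::nn::GRU(torch::nn::GRUOptions(/*input_size=*/10, /*hidden_size=*/20));
      auto input = torch::randn({5, 3, 10});                   // (seq_len, batch, input_size)
      auto lengths = torch::tensor({5, 3, 2}, torch::kInt64);  // descending per-sequence lengths
      auto packed = rnn_utils::pack_padded_sequence(input, lengths);

      auto result = gru->forward_with_packed_input(packed);    // hx defaults to zeros
      auto packed_output = std::get<0>(result);                // PackedSequence
      auto h_n = std::get<1>(result);                          // (num_layers * num_directions, 3, 20)
      return 0;
    }
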
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch/csrc/api/src/nn/options/rnn.cpp b/torch/csrc/api/src/nn/options/rnn.cpp
index 33b6d53..d4a0841 100644
--- a/torch/csrc/api/src/nn/options/rnn.cpp
+++ b/torch/csrc/api/src/nn/options/rnn.cpp
@@ -5,21 +5,19 @@
namespace detail {
-RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
- : input_size_(input_size), hidden_size_(hidden_size) {}
+RNNOptionsBase::RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size)
+ : mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {}
} // namespace detail
RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
: input_size_(input_size), hidden_size_(hidden_size) {}
-RNNOptions& RNNOptions::tanh() {
- return activation(RNNActivation::Tanh);
-}
+LSTMOptions::LSTMOptions(int64_t input_size, int64_t hidden_size)
+ : input_size_(input_size), hidden_size_(hidden_size) {}
-RNNOptions& RNNOptions::relu() {
- return activation(RNNActivation::ReLU);
-}
+GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size)
+ : input_size_(input_size), hidden_size_(hidden_size) {}
namespace detail {