[C++ API] RNN / GRU / LSTM layer refactoring (#34322)

Summary:
This PR refactors RNN / GRU / LSTM layers in C++ API to exactly match the implementation in Python API.

**BC-breaking changes:**
- Instead of returning `RNNOutput`, RNN / GRU forward method now returns `std::tuple<Tensor, Tensor>`, and LSTM forward method now returns `std::tuple<Tensor, std::tuple<Tensor, Tensor>>`, matching Python API.
- RNN / LSTM / GRU forward method now accepts the same inputs (input tensor and optionally hidden state), matching Python API.
- RNN / LSTM / GRU now has `forward_with_packed_input` method which accepts `PackedSequence` as input and optionally hidden state, matching the `forward(PackedSequence, ...)` variant in Python API.
- In `RNNOptions`
    - `tanh()` / `relu()` / `activation` are removed. Instead, `nonlinearity` is added which takes either `torch::kTanh` or `torch::kReLU`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`
- In `LSTMOptions`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`
- In `GRUOptions`
    - `layers` -> `num_layers`
    - `with_bias` -> `bias`

The majority of the changes in this PR focused on refactoring the implementations in `torch/csrc/api/src/nn/modules/rnn.cpp` to match the Python API. RNN tests are then changed to reflected the revised API design.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34322

Differential Revision: D20311699

Pulled By: yf225

fbshipit-source-id: e2b60fc7bac64367a8434647d74c08568a7b28f7
diff --git a/test/cpp/api/enum.cpp b/test/cpp/api/enum.cpp
index 12a8879..ac20edb 100644
--- a/test/cpp/api/enum.cpp
+++ b/test/cpp/api/enum.cpp
@@ -43,7 +43,11 @@
     torch::enumtype::kBatchMean,
     torch::enumtype::kZeros,
     torch::enumtype::kBorder,
-    torch::enumtype::kReflection
+    torch::enumtype::kReflection,
+    torch::enumtype::kRNN_TANH,
+    torch::enumtype::kRNN_RELU,
+    torch::enumtype::kLSTM,
+    torch::enumtype::kGRU
   > v;
 
   TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
@@ -76,4 +80,8 @@
   TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
   TORCH_ENUM_PRETTY_PRINT_TEST(Border)
   TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
+  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
+  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
+  TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
+  TORCH_ENUM_PRETTY_PRINT_TEST(GRU) 
 }
diff --git a/test/cpp/api/modulelist.cpp b/test/cpp/api/modulelist.cpp
index 3a688d0..5aa4ccb 100644
--- a/test/cpp/api/modulelist.cpp
+++ b/test/cpp/api/modulelist.cpp
@@ -283,7 +283,7 @@
       "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
       "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 }
 
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index c17a0ab..c51c386 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -23,7 +23,7 @@
     auto B = x.size(1);
     x = x.view({T * B, 1});
     x = l1->forward(x).view({T, B, nhid}).tanh_();
-    x = rnn->forward(x).output[T - 1];
+    x = std::get<0>(rnn->forward(x))[T - 1];
     x = lo->forward(x);
     return x;
   };
@@ -61,29 +61,39 @@
   return true;
 };
 
-void check_lstm_sizes(RNNOutput output) {
+void check_lstm_sizes(std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output) {
   // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
   // 10 and 16 time steps (10 x 16 x n)
 
-  ASSERT_EQ(output.output.ndimension(), 3);
-  ASSERT_EQ(output.output.size(0), 10);
-  ASSERT_EQ(output.output.size(1), 16);
-  ASSERT_EQ(output.output.size(2), 64);
+  torch::Tensor output = std::get<0>(lstm_output);
+  std::tuple<torch::Tensor, torch::Tensor> state = std::get<1>(lstm_output);
+  torch::Tensor hx = std::get<0>(state);
+  torch::Tensor cx = std::get<1>(state);
 
-  ASSERT_EQ(output.state.ndimension(), 4);
-  ASSERT_EQ(output.state.size(0), 2); // (hx, cx)
-  ASSERT_EQ(output.state.size(1), 3); // layers
-  ASSERT_EQ(output.state.size(2), 16); // Batchsize
-  ASSERT_EQ(output.state.size(3), 64); // 64 hidden dims
+  ASSERT_EQ(output.ndimension(), 3);
+  ASSERT_EQ(output.size(0), 10);
+  ASSERT_EQ(output.size(1), 16);
+  ASSERT_EQ(output.size(2), 64);
+
+  ASSERT_EQ(hx.ndimension(), 3);
+  ASSERT_EQ(hx.size(0), 3); // layers
+  ASSERT_EQ(hx.size(1), 16); // Batchsize
+  ASSERT_EQ(hx.size(2), 64); // 64 hidden dims
+
+  ASSERT_EQ(cx.ndimension(), 3);
+  ASSERT_EQ(cx.size(0), 3); // layers
+  ASSERT_EQ(cx.size(1), 16); // Batchsize
+  ASSERT_EQ(cx.size(2), 64); // 64 hidden dims
 
   // Something is in the hiddens
-  ASSERT_GT(output.state.norm().item<float>(), 0);
+  ASSERT_GT(hx.norm().item<float>(), 0);
+  ASSERT_GT(cx.norm().item<float>(), 0);
 }
 
 struct RNNTest : torch::test::SeedingFixture {};
 
 TEST_F(RNNTest, CheckOutputSizes) {
-  LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
   // Input size is: sequence length, batch size, input size
   auto x = torch::randn({10, 16, 128}, torch::requires_grad());
   auto output = model->forward(x);
@@ -92,11 +102,17 @@
   y.backward();
   check_lstm_sizes(output);
 
-  auto next = model->forward(x, output.state);
+  auto next = model->forward(x, std::get<1>(output));
 
   check_lstm_sizes(next);
 
-  torch::Tensor diff = next.state - output.state;
+  auto output_hx = std::get<0>(std::get<1>(output));
+  auto output_cx = std::get<1>(std::get<1>(output));
+
+  auto next_hx = std::get<0>(std::get<1>(next));
+  auto next_cx = std::get<1>(std::get<1>(next));
+
+  torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
 
   // Hiddens changed
   ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -122,12 +138,12 @@
   }
 
   auto out = model->forward(x);
-  ASSERT_EQ(out.output.ndimension(), 3);
-  ASSERT_EQ(out.output.size(0), 3);
-  ASSERT_EQ(out.output.size(1), 4);
-  ASSERT_EQ(out.output.size(2), 2);
+  ASSERT_EQ(std::get<0>(out).ndimension(), 3);
+  ASSERT_EQ(std::get<0>(out).size(0), 3);
+  ASSERT_EQ(std::get<0>(out).size(1), 4);
+  ASSERT_EQ(std::get<0>(out).size(2), 2);
 
-  auto flat = out.output.view(3 * 4 * 2);
+  auto flat = std::get<0>(out).view(3 * 4 * 2);
   float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
                    0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
                    0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -136,12 +152,20 @@
     ASSERT_LT(std::abs(flat[i].item<float>() - c_out[i]), 1e-3);
   }
 
-  ASSERT_EQ(out.state.ndimension(), 4); // (hx, cx) x layers x B x 2
-  ASSERT_EQ(out.state.size(0), 2);
-  ASSERT_EQ(out.state.size(1), 1);
-  ASSERT_EQ(out.state.size(2), 4);
-  ASSERT_EQ(out.state.size(3), 2);
-  flat = out.state.view(16);
+  auto hx = std::get<0>(std::get<1>(out));
+  auto cx = std::get<1>(std::get<1>(out));
+
+  ASSERT_EQ(hx.ndimension(), 3); // layers x B x 2
+  ASSERT_EQ(hx.size(0), 1);
+  ASSERT_EQ(hx.size(1), 4);
+  ASSERT_EQ(hx.size(2), 2);
+
+  ASSERT_EQ(cx.ndimension(), 3); // layers x B x 2
+  ASSERT_EQ(cx.size(0), 1);
+  ASSERT_EQ(cx.size(1), 4);
+  ASSERT_EQ(cx.size(2), 2);
+
+  flat = torch::cat({hx, cx}, 0).view(16);
   float h_out[] = {0.7889,
                    0.9003,
                    0.7769,
@@ -165,27 +189,27 @@
 
 TEST_F(RNNTest, EndToEndLSTM) {
   ASSERT_TRUE(test_RNN_xor<LSTM>(
-      [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }));
+      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndGRU) {
   ASSERT_TRUE(
-      test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).layers(2)); }));
+      test_RNN_xor<GRU>([](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndRNNRelu) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, EndToEndRNNTanh) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }));
 }
 
 TEST_F(RNNTest, Sizes_CUDA) {
   torch::manual_seed(0);
-  LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
+  LSTM model(LSTMOptions(128, 64).num_layers(3).dropout(0.2));
   model->to(torch::kCUDA);
   auto x =
       torch::randn({10, 16, 128}, torch::requires_grad().device(torch::kCUDA));
@@ -195,11 +219,17 @@
   y.backward();
   check_lstm_sizes(output);
 
-  auto next = model->forward(x, output.state);
+  auto next = model->forward(x, std::get<1>(output));
 
   check_lstm_sizes(next);
 
-  torch::Tensor diff = next.state - output.state;
+  auto output_hx = std::get<0>(std::get<1>(output));
+  auto output_cx = std::get<1>(std::get<1>(output));
+
+  auto next_hx = std::get<0>(std::get<1>(next));
+  auto next_cx = std::get<1>(std::get<1>(next));
+
+  torch::Tensor diff = torch::cat({next_hx, next_cx}, 0) - torch::cat({output_hx, output_cx}, 0);
 
   // Hiddens changed
   ASSERT_GT(diff.abs().sum().item<float>(), 1e-3);
@@ -207,51 +237,68 @@
 
 TEST_F(RNNTest, EndToEndLSTM_CUDA) {
   ASSERT_TRUE(test_RNN_xor<LSTM>(
-      [](int s) { return LSTM(LSTMOptions(s, s).layers(2)); }, true));
+      [](int s) { return LSTM(LSTMOptions(s, s).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, EndToEndGRU_CUDA) {
   ASSERT_TRUE(test_RNN_xor<GRU>(
-      [](int s) { return GRU(GRUOptions(s, s).layers(2)); }, true));
+      [](int s) { return GRU(GRUOptions(s, s).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, EndToEndRNNRelu_CUDA) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).relu().layers(2)); }, true));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kReLU).num_layers(2)); }, true));
 }
 TEST_F(RNNTest, EndToEndRNNTanh_CUDA) {
   ASSERT_TRUE(test_RNN_xor<RNN>(
-      [](int s) { return RNN(RNNOptions(s, s).tanh().layers(2)); }, true));
+      [](int s) { return RNN(RNNOptions(s, s).nonlinearity(torch::kTanh).num_layers(2)); }, true));
 }
 
 TEST_F(RNNTest, PrettyPrintRNNs) {
   ASSERT_EQ(
-      c10::str(LSTM(LSTMOptions(128, 64).layers(3).dropout(0.2))),
-      "torch::nn::LSTM(input_size=128, hidden_size=64, layers=3, dropout=0.2)");
+      c10::str(LSTM(LSTMOptions(128, 64).num_layers(3).dropout(0.2))),
+      "torch::nn::LSTM(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
   ASSERT_EQ(
-      c10::str(GRU(GRUOptions(128, 64).layers(3).dropout(0.5))),
-      "torch::nn::GRU(input_size=128, hidden_size=64, layers=3, dropout=0.5)");
+      c10::str(GRU(GRUOptions(128, 64).num_layers(3).dropout(0.5))),
+      "torch::nn::GRU(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.5, bidirectional=false)");
   ASSERT_EQ(
-      c10::str(RNN(RNNOptions(128, 64).layers(3).dropout(0.2).tanh())),
-      "torch::nn::RNN(input_size=128, hidden_size=64, layers=3, dropout=0.2, activation=tanh)");
+      c10::str(RNN(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh))),
+      "torch::nn::RNN(input_size=128, hidden_size=64, num_layers=3, bias=true, batch_first=false, dropout=0.2, bidirectional=false)");
 }
 
 // This test assures that flatten_parameters does not crash,
 // when bidirectional is set to true
 // https://github.com/pytorch/pytorch/issues/19545
 TEST_F(RNNTest, BidirectionalFlattenParameters) {
-  GRU gru(GRUOptions(100, 256).layers(2).bidirectional(true));
+  GRU gru(GRUOptions(100, 256).num_layers(2).bidirectional(true));
   gru->flatten_parameters();
 }
 
 template <typename Impl>
-void copyParameters(torch::nn::ModuleHolder<Impl>& target, size_t t_i,
-                    const torch::nn::ModuleHolder<Impl>& source, size_t s_i) {
+void copyParameters(torch::nn::ModuleHolder<Impl>& target, std::string t_suffix,
+                    const torch::nn::ModuleHolder<Impl>& source, std::string s_suffix) {
   at::NoGradGuard guard;
-  target->w_ih[t_i].copy_(source->w_ih[s_i]);
-  target->w_hh[t_i].copy_(source->w_hh[s_i]);
-  target->b_ih[t_i].copy_(source->b_ih[s_i]);
-  target->b_hh[t_i].copy_(source->b_hh[s_i]);
+  target->named_parameters()["weight_ih_l" + t_suffix].copy_(source->named_parameters()["weight_ih_l" + s_suffix]);
+  target->named_parameters()["weight_hh_l" + t_suffix].copy_(source->named_parameters()["weight_hh_l" + s_suffix]);
+  target->named_parameters()["bias_ih_l" + t_suffix].copy_(source->named_parameters()["bias_ih_l" + s_suffix]);
+  target->named_parameters()["bias_hh_l" + t_suffix].copy_(source->named_parameters()["bias_hh_l" + s_suffix]);
+}
+
+std::tuple<torch::Tensor, torch::Tensor> gru_output_to_device(
+  std::tuple<torch::Tensor, torch::Tensor> gru_output, torch::Device device) {
+  return std::make_tuple(
+    std::get<0>(gru_output).to(device),
+    std::get<1>(gru_output).to(device));
+}
+
+std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output_to_device(
+  std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>> lstm_output, torch::Device device) {
+  auto hidden_states = std::get<1>(lstm_output);
+  return std::make_tuple(
+    std::get<0>(lstm_output).to(device),
+    std::make_tuple(
+      std::get<0>(hidden_states).to(device),
+      std::get<1>(hidden_states).to(device)));
 }
 
 // This test is a port of python code introduced here:
@@ -264,7 +311,7 @@
   auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
   auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
 
-  auto gru_options = GRUOptions(1, 1).layers(1).batch_first(false);
+  auto gru_options = GRUOptions(1, 1).num_layers(1).batch_first(false);
   GRU bi_grus {gru_options.bidirectional(true)};
   GRU reverse_gru {gru_options.bidirectional(false)};
 
@@ -275,28 +322,26 @@
 
   // Now make sure the weights of the reverse gru layer match
   // ones of the (reversed) bidirectional's:
-  copyParameters(reverse_gru, 0, bi_grus, 1);
+  copyParameters(reverse_gru, "0", bi_grus, "0_reverse");
 
   auto bi_output = bi_grus->forward(input);
   auto reverse_output = reverse_gru->forward(input_reversed);
 
   if (cuda) {
-    bi_output.output = bi_output.output.to(torch::kCPU);
-    bi_output.state = bi_output.state.to(torch::kCPU);
-    reverse_output.output = reverse_output.output.to(torch::kCPU);
-    reverse_output.state = reverse_output.state.to(torch::kCPU);
+    bi_output = gru_output_to_device(bi_output, torch::kCPU);
+    reverse_output = gru_output_to_device(reverse_output, torch::kCPU);
   }
 
-  ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
-  auto size = bi_output.output.size(0);
+  ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+  auto size = std::get<0>(bi_output).size(0);
   for (int i = 0; i < size; i++) {
-    ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
-              reverse_output.output[size - 1 - i][0][0].item<float>());
+    ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+              std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
   }
   // The hidden states of the reversed GRUs sits
   // in the odd indices in the first dimension.
-  ASSERT_EQ(bi_output.state[1][0][0].item<float>(),
-            reverse_output.state[0][0][0].item<float>());
+  ASSERT_EQ(std::get<1>(bi_output)[1][0][0].item<float>(),
+            std::get<1>(reverse_output)[0][0][0].item<float>());
 }
 
 TEST_F(RNNTest, BidirectionalGRUReverseForward) {
@@ -315,7 +360,7 @@
   auto input = torch::tensor({1, 2, 3, 4, 5}, opt).reshape({5, 1, 1});
   auto input_reversed = torch::tensor({5, 4, 3, 2, 1}, opt).reshape({5, 1, 1});
 
-  auto lstm_opt = GRUOptions(1, 1).layers(1).batch_first(false);
+  auto lstm_opt = LSTMOptions(1, 1).num_layers(1).batch_first(false);
 
   LSTM bi_lstm {lstm_opt.bidirectional(true)};
   LSTM reverse_lstm {lstm_opt.bidirectional(false)};
@@ -327,30 +372,28 @@
 
   // Now make sure the weights of the reverse lstm layer match
   // ones of the (reversed) bidirectional's:
-  copyParameters(reverse_lstm, 0, bi_lstm, 1);
+  copyParameters(reverse_lstm, "0", bi_lstm, "0_reverse");
 
   auto bi_output = bi_lstm->forward(input);
   auto reverse_output = reverse_lstm->forward(input_reversed);
 
   if (cuda) {
-    bi_output.output = bi_output.output.to(torch::kCPU);
-    bi_output.state = bi_output.state.to(torch::kCPU);
-    reverse_output.output = reverse_output.output.to(torch::kCPU);
-    reverse_output.state = reverse_output.state.to(torch::kCPU);
+    bi_output = lstm_output_to_device(bi_output, torch::kCPU);
+    reverse_output = lstm_output_to_device(reverse_output, torch::kCPU);
   }
 
-  ASSERT_EQ(bi_output.output.size(0), reverse_output.output.size(0));
-  auto size = bi_output.output.size(0);
+  ASSERT_EQ(std::get<0>(bi_output).size(0), std::get<0>(reverse_output).size(0));
+  auto size = std::get<0>(bi_output).size(0);
   for (int i = 0; i < size; i++) {
-    ASSERT_EQ(bi_output.output[i][0][1].item<float>(),
-              reverse_output.output[size - 1 - i][0][0].item<float>());
+    ASSERT_EQ(std::get<0>(bi_output)[i][0][1].item<float>(),
+              std::get<0>(reverse_output)[size - 1 - i][0][0].item<float>());
   }
   // The hidden states of the reversed LSTM sits
   // in the odd indices in the first dimension.
-  ASSERT_EQ(bi_output.state[0][1][0][0].item<float>(),
-            reverse_output.state[0][0][0][0].item<float>());
-  ASSERT_EQ(bi_output.state[1][1][0][0].item<float>(),
-            reverse_output.state[1][0][0][0].item<float>());
+  ASSERT_EQ(std::get<0>(std::get<1>(bi_output))[1][0][0].item<float>(),
+            std::get<0>(std::get<1>(reverse_output))[0][0][0].item<float>());
+  ASSERT_EQ(std::get<1>(std::get<1>(bi_output))[1][0][0].item<float>(),
+            std::get<1>(std::get<1>(reverse_output))[0][0][0].item<float>());
 }
 
 TEST_F(RNNTest, BidirectionalLSTMReverseForward) {
@@ -363,19 +406,15 @@
 
 TEST_F(RNNTest, BidirectionalMultilayerGRU_CPU_vs_CUDA) {
   // Create two GRUs with the same options
-  auto opt = GRUOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+  auto opt = GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
   GRU gru_cpu {opt};
   GRU gru_cuda {opt};
 
   // Copy weights and biases from CPU GRU to CUDA GRU
   {
     at::NoGradGuard guard;
-    const auto num_directions = gru_cpu->options.bidirectional() ? 2 : 1;
-    for (int64_t layer = 0; layer < gru_cpu->options.layers(); layer++) {
-      for (auto direction = 0; direction < num_directions; direction++) {
-        const auto layer_idx = (layer * num_directions) + direction;
-        copyParameters(gru_cuda, layer_idx, gru_cpu, layer_idx);
-      }
+    for (const auto& param : gru_cpu->named_parameters(/*recurse=*/false)) {
+      gru_cuda->named_parameters()[param.key()].copy_(gru_cpu->named_parameters()[param.key()]);
     }
   }
 
@@ -397,20 +436,19 @@
   auto output_cpu = gru_cpu->forward(input_cpu);
   auto output_cuda = gru_cuda->forward(input_cuda);
 
-  output_cpu.output = output_cpu.output.to(torch::kCPU);
-  output_cpu.state = output_cpu.state.to(torch::kCPU);
+  output_cpu = gru_output_to_device(output_cpu, torch::kCPU);
 
   // Assert that the output and state are equal on CPU and CUDA
-  ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
-  for (int i = 0; i < output_cpu.output.dim(); i++) {
-    ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+    ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
   }
-  for (int i = 0; i < output_cpu.output.size(0); i++) {
-    for (int j = 0; j < output_cpu.output.size(1); j++) {
-      for (int k = 0; k < output_cpu.output.size(2); k++) {
+  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
         ASSERT_NEAR(
-          output_cpu.output[i][j][k].item<float>(),
-          output_cuda.output[i][j][k].item<float>(), 1e-5);
+          std::get<0>(output_cpu)[i][j][k].item<float>(),
+          std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
       }
     }
   }
@@ -418,19 +456,15 @@
 
 TEST_F(RNNTest, BidirectionalMultilayerLSTM_CPU_vs_CUDA) {
   // Create two LSTMs with the same options
-  auto opt = LSTMOptions(2, 4).layers(3).batch_first(false).bidirectional(true);
+  auto opt = LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true);
   LSTM lstm_cpu {opt};
   LSTM lstm_cuda {opt};
 
   // Copy weights and biases from CPU LSTM to CUDA LSTM
   {
     at::NoGradGuard guard;
-    const auto num_directions = lstm_cpu->options.bidirectional() ? 2 : 1;
-    for (int64_t layer = 0; layer < lstm_cpu->options.layers(); layer++) {
-      for (auto direction = 0; direction < num_directions; direction++) {
-        const auto layer_idx = (layer * num_directions) + direction;
-        copyParameters(lstm_cuda, layer_idx, lstm_cpu, layer_idx);
-      }
+    for (const auto& param : lstm_cpu->named_parameters(/*recurse=*/false)) {
+      lstm_cuda->named_parameters()[param.key()].copy_(lstm_cpu->named_parameters()[param.key()]);
     }
   }
 
@@ -451,21 +485,68 @@
   auto output_cpu = lstm_cpu->forward(input_cpu);
   auto output_cuda = lstm_cuda->forward(input_cuda);
 
-  output_cpu.output = output_cpu.output.to(torch::kCPU);
-  output_cpu.state = output_cpu.state.to(torch::kCPU);
+  output_cpu = lstm_output_to_device(output_cpu, torch::kCPU);
 
   // Assert that the output and state are equal on CPU and CUDA
-  ASSERT_EQ(output_cpu.output.dim(), output_cuda.output.dim());
-  for (int i = 0; i < output_cpu.output.dim(); i++) {
-    ASSERT_EQ(output_cpu.output.size(i), output_cuda.output.size(i));
+  ASSERT_EQ(std::get<0>(output_cpu).dim(), std::get<0>(output_cuda).dim());
+  for (int i = 0; i < std::get<0>(output_cpu).dim(); i++) {
+    ASSERT_EQ(std::get<0>(output_cpu).size(i), std::get<0>(output_cuda).size(i));
   }
-  for (int i = 0; i < output_cpu.output.size(0); i++) {
-    for (int j = 0; j < output_cpu.output.size(1); j++) {
-      for (int k = 0; k < output_cpu.output.size(2); k++) {
+  for (int i = 0; i < std::get<0>(output_cpu).size(0); i++) {
+    for (int j = 0; j < std::get<0>(output_cpu).size(1); j++) {
+      for (int k = 0; k < std::get<0>(output_cpu).size(2); k++) {
         ASSERT_NEAR(
-          output_cpu.output[i][j][k].item<float>(),
-          output_cuda.output[i][j][k].item<float>(), 1e-5);
+          std::get<0>(output_cpu)[i][j][k].item<float>(),
+          std::get<0>(output_cuda)[i][j][k].item<float>(), 1e-5);
       }
     }
   }
 }
+
+TEST_F(RNNTest, UsePackedSequenceAsInput) {
+  {
+    torch::manual_seed(0);
+    auto m = RNN(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+    {{-0.0645, -0.7274,  0.4531},
+     {-0.3970, -0.6950,  0.6009},
+     {-0.3877, -0.7310,  0.6806}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `RNN::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+  {
+    torch::manual_seed(0);
+    auto m = LSTM(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+    {{-0.2693, -0.1240,  0.0744},
+     {-0.3889, -0.1919,  0.1183},
+     {-0.4425, -0.2314,  0.1386}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `LSTM::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::nullopt);
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+  {
+    torch::manual_seed(0);
+    auto m = GRU(2, 3);
+    torch::nn::utils::rnn::PackedSequence packed_input = torch::nn::utils::rnn::pack_sequence({torch::ones({3, 2})});
+    auto rnn_output = m->forward_with_packed_input(packed_input);
+    auto expected_output = torch::tensor(
+    {{-0.1134,  0.0467,  0.2336},
+     {-0.1189,  0.0502,  0.2960},
+     {-0.1138,  0.0484,  0.3110}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+
+    // Test passing optional argument to `GRU::forward_with_packed_input`
+    rnn_output = m->forward_with_packed_input(packed_input, torch::Tensor());
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output).data(), expected_output, 1e-05, 2e-04));
+  }
+}
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index 9107781..2bd3fd9 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -410,7 +410,7 @@
       "  (2): torch::nn::Dropout(p=0.5, inplace=false)\n"
       "  (3): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (4): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      "  (5): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 
   Sequential sequential_named({
@@ -429,7 +429,7 @@
       "  (dropout): torch::nn::Dropout(p=0.5, inplace=false)\n"
       "  (batchnorm2d): torch::nn::BatchNorm2d(5, eps=1e-05, momentum=0.1, affine=true, track_running_stats=true)\n"
       "  (embedding): torch::nn::Embedding(num_embeddings=4, embedding_dim=10)\n"
-      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, layers=1, dropout=0)\n"
+      "  (lstm): torch::nn::LSTM(input_size=4, hidden_size=5, num_layers=1, bias=true, batch_first=false, dropout=0, bidirectional=false)\n"
       ")");
 }
 
@@ -598,21 +598,21 @@
     torch::manual_seed(0);
     Sequential sequential(Identity(), RNN(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
     auto expected_output = torch::tensor(
-    {{{0.0000, 0.0000,  0.4886},
-      {0.0000, 0.0000,  0.4886},
-      {0.0000, 0.0000,  0.4886}},
-     {{0.0000, 0.0000,  0.3723},
-      {0.0000, 0.0000,  0.3723},
-      {0.0000, 0.0000,  0.3723}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+    {{{-0.0645, -0.7274,  0.4531},
+      {-0.0645, -0.7274,  0.4531},
+      {-0.0645, -0.7274,  0.4531}},
+     {{-0.3970, -0.6950,  0.6009},
+      {-0.3970, -0.6950,  0.6009},
+      {-0.3970, -0.6950,  0.6009}}});
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
     Sequential sequential(Identity(), LSTM(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>>(x);
     auto expected_output = torch::tensor(
     {{{-0.2693, -0.1240,  0.0744},
       {-0.2693, -0.1240,  0.0744},
@@ -620,13 +620,13 @@
      {{-0.3889, -0.1919,  0.1183},
       {-0.3889, -0.1919,  0.1183},
       {-0.3889, -0.1919,  0.1183}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
     Sequential sequential(Identity(), GRU(2, 3));
     auto x = torch::ones({2, 3, 2});
-    auto rnn_output = sequential->forward<RNNOutput>(x);
+    auto rnn_output = sequential->forward<std::tuple<torch::Tensor, torch::Tensor>>(x);
     auto expected_output = torch::tensor(
     {{{-0.1134,  0.0467,  0.2336},
       {-0.1134,  0.0467,  0.2336},
@@ -634,7 +634,7 @@
      {{-0.1189,  0.0502,  0.2960},
       {-0.1189,  0.0502,  0.2960},
       {-0.1189,  0.0502,  0.2960}}});
-    ASSERT_TRUE(torch::allclose(rnn_output.output, expected_output, 1e-05, 2e-04));
+    ASSERT_TRUE(torch::allclose(std::get<0>(rnn_output), expected_output, 1e-05, 2e-04));
   }
   {
     torch::manual_seed(0);
diff --git a/test/cpp_api_parity/parity-tracker.md b/test/cpp_api_parity/parity-tracker.md
index e9940d8..0853b96 100644
--- a/test/cpp_api_parity/parity-tracker.md
+++ b/test/cpp_api_parity/parity-tracker.md
@@ -81,9 +81,9 @@
 torch::nn::LayerNorm|Yes|No
 torch::nn::LocalResponseNorm|Yes|No
 torch::nn::CrossMapLRN2d|Yes|No
-torch::nn::RNN|No|No
-torch::nn::LSTM|No|No
-torch::nn::GRU|No|No
+torch::nn::RNN|Yes|No
+torch::nn::LSTM|Yes|No
+torch::nn::GRU|Yes|No
 torch::nn::RNNCell|Yes|No
 torch::nn::LSTMCell|Yes|No
 torch::nn::GRUCell|Yes|No
diff --git a/torch/csrc/api/include/torch/enum.h b/torch/csrc/api/include/torch/enum.h
index fdd1c3c..2f28530 100644
--- a/torch/csrc/api/include/torch/enum.h
+++ b/torch/csrc/api/include/torch/enum.h
@@ -122,6 +122,10 @@
 TORCH_ENUM_DECLARE(Zeros)
 TORCH_ENUM_DECLARE(Border)
 TORCH_ENUM_DECLARE(Reflection)
+TORCH_ENUM_DECLARE(RNN_TANH)
+TORCH_ENUM_DECLARE(RNN_RELU)
+TORCH_ENUM_DECLARE(LSTM)
+TORCH_ENUM_DECLARE(GRU)
 
 namespace torch {
 namespace enumtype {
@@ -157,6 +161,10 @@
   TORCH_ENUM_PRETTY_PRINT(Zeros)
   TORCH_ENUM_PRETTY_PRINT(Border)
   TORCH_ENUM_PRETTY_PRINT(Reflection)
+  TORCH_ENUM_PRETTY_PRINT(RNN_TANH)
+  TORCH_ENUM_PRETTY_PRINT(RNN_RELU)
+  TORCH_ENUM_PRETTY_PRINT(LSTM)
+  TORCH_ENUM_PRETTY_PRINT(GRU)
 };
 
 template <typename V>
diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h
index d065faf..d324453 100644
--- a/torch/csrc/api/include/torch/nn/modules/rnn.h
+++ b/torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -5,6 +5,7 @@
 #include <torch/nn/modules/common.h>
 #include <torch/nn/modules/dropout.h>
 #include <torch/nn/pimpl.h>
+#include <torch/nn/utils/rnn.h>
 #include <torch/types.h>
 
 #include <ATen/ATen.h>
@@ -15,36 +16,23 @@
 #include <memory>
 #include <vector>
 
+using namespace torch::nn::utils::rnn;
+
 namespace torch {
 namespace nn {
 
-/// The output of a single invocation of an RNN module's `forward()` method.
-struct TORCH_API RNNOutput {
-  /// The result of applying the specific RNN algorithm
-  /// to the input tensor and input state.
-  Tensor output;
-  /// The new, updated state that can be fed into the RNN
-  /// in the next forward step.
-  Tensor state;
-};
-
 namespace detail {
 /// Base class for all RNN implementations (intended for code sharing).
 template <typename Derived>
 class TORCH_API RNNImplBase : public torch::nn::Cloneable<Derived> {
  public:
-  /// These must line up with the CUDNN mode codes:
-  /// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
-  enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
-
-  explicit RNNImplBase(
-      const RNNOptionsBase& options_,
-      optional<CuDNNMode> cudnn_mode = nullopt,
-      int64_t number_of_gates = 1);
+  explicit RNNImplBase(const RNNOptionsBase& options_);
 
   /// Initializes the parameters of the RNN module.
   void reset() override;
 
+  void reset_parameters();
+
   /// Overrides `nn::Module::to()` to call `flatten_parameters()` after the
   /// original operation.
   void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
@@ -65,52 +53,32 @@
   /// called once upon construction, inside `reset()`.
   void flatten_parameters();
 
-  /// The RNN's options.
-  RNNOptionsBase options;
+  std::vector<Tensor> all_weights() const;
 
-  /// The weights for `input x hidden` gates.
-  std::vector<Tensor> w_ih;
-  /// The weights for `hidden x hidden` gates.
-  std::vector<Tensor> w_hh;
-  /// The biases for `input x hidden` gates.
-  std::vector<Tensor> b_ih;
-  /// The biases for `hidden x hidden` gates.
-  std::vector<Tensor> b_hh;
+  /// The RNN's options.
+  RNNOptionsBase options_base;
 
  protected:
-  /// The function signature of `rnn_relu`, `rnn_tanh` and `gru`.
-  using RNNFunctionSignature = std::tuple<Tensor, Tensor>(
-      /*input=*/const Tensor&,
-      /*state=*/const Tensor&,
-      /*params=*/TensorList,
-      /*has_biases=*/bool,
-      /*layers=*/int64_t,
-      /*dropout=*/double,
-      /*train=*/bool,
-      /*bidirectional=*/bool,
-      /*batch_first=*/bool);
+  // Resets flat_weights_
+  // Note: be v. careful before removing this, as 3rd party device types
+  // likely rely on this behavior to properly .to() modules like LSTM.
+  void reset_flat_weights();
 
-  /// A generic `forward()` used for RNN and GRU (but not LSTM!). Takes the ATen
-  /// RNN function as first argument.
-  RNNOutput generic_forward(
-      std::function<RNNFunctionSignature> function,
-      const Tensor& input,
-      Tensor state);
+  void check_input(const Tensor& input, const Tensor& batch_sizes) const;
 
-  /// Returns a flat vector of all weights, with layer weights following each
-  /// other sequentially in (w_ih, w_hh, b_ih, b_hh) order.
-  std::vector<Tensor> flat_weights() const;
+  std::tuple<int64_t, int64_t, int64_t> get_expected_hidden_size(const Tensor& input, const Tensor& batch_sizes) const;
 
-  /// Very simple check if any of the parameters (weights, biases) are the same.
-  bool any_parameters_alias() const;
+  void check_hidden_size(
+    const Tensor& hx,
+    std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
+    std::string msg = "Expected hidden size {1}, got {2}") const;
 
-  /// The number of gate weights/biases required by the RNN subclass.
-  int64_t number_of_gates_;
+  void check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const;
 
-  /// The cuDNN RNN mode, if this RNN subclass has any.
-  optional<CuDNNMode> cudnn_mode_;
+  Tensor permute_hidden(Tensor hx, const Tensor& permutation) const;
 
-  /// The cached result of the latest `flat_weights()` call.
+  std::vector<std::string> flat_weights_names_;
+  std::vector<std::vector<std::string>> all_weights_;
   std::vector<Tensor> flat_weights_;
 };
 } // namespace detail
@@ -118,84 +86,139 @@
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer Elman RNN module with Tanh or ReLU activation.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.RNN to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::RNNOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
 class TORCH_API RNNImpl : public detail::RNNImplBase<RNNImpl> {
  public:
   RNNImpl(int64_t input_size, int64_t hidden_size)
       : RNNImpl(RNNOptions(input_size, hidden_size)) {}
   explicit RNNImpl(const RNNOptions& options_);
 
-  /// Pretty prints the `RNN` module into the given `stream`.
-  void pretty_print(std::ostream& stream) const override;
-
-  /// Applies the `RNN` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
  protected:
   FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+
  public:
+  std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
   RNNOptions options;
+
+ protected:
+  std::tuple<Tensor, Tensor> forward_helper(
+    const Tensor& input,
+    const Tensor& batch_sizes,
+    const Tensor& sorted_indices,
+    int64_t max_batch_size,
+    Tensor hx);
 };
 
 /// A `ModuleHolder` subclass for `RNNImpl`.
-/// See the documentation for `RNNImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `RNNImpl` class to learn what methods it
+/// provides, and examples of how to use `RNN` with `torch::nn::RNNOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
 TORCH_MODULE(RNN);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer long-short-term-memory (LSTM) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.LSTM to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::LSTMOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
 class TORCH_API LSTMImpl : public detail::RNNImplBase<LSTMImpl> {
  public:
   LSTMImpl(int64_t input_size, int64_t hidden_size)
       : LSTMImpl(LSTMOptions(input_size, hidden_size)) {}
   explicit LSTMImpl(const LSTMOptions& options_);
 
-  /// Applies the `LSTM` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward(
+    const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
  protected:
-  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::optional<std::tuple<Tensor, Tensor>>())})
+
+ public:
+  std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> forward_with_packed_input(
+    const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt = {});
+
+  LSTMOptions options;
+
+ protected:
+  void check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const;
+
+  std::tuple<Tensor, Tensor> permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const;
+
+  std::tuple<Tensor, std::tuple<Tensor, Tensor>> forward_helper(
+    const Tensor& input,
+    const Tensor& batch_sizes,
+    const Tensor& sorted_indices,
+    int64_t max_batch_size,
+    torch::optional<std::tuple<Tensor, Tensor>> hx_opt);
 };
 
 /// A `ModuleHolder` subclass for `LSTMImpl`.
 /// See the documentation for `LSTMImpl` class to learn what methods it
-/// provides, or the documentation for `ModuleHolder` to learn about PyTorch's
+/// provides, and examples of how to use `LSTM` with `torch::nn::LSTMOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
 /// module storage semantics.
 TORCH_MODULE(LSTM);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 /// A multi-layer gated recurrent unit (GRU) module.
-/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn about the
-/// exact behavior of this module.
+/// See https://pytorch.org/docs/master/nn.html#torch.nn.GRU to learn
+/// about the exact behavior of this module.
+///
+/// See the documentation for `torch::nn::GRUOptions` class to learn what
+/// constructor arguments are supported for this module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
 class TORCH_API GRUImpl : public detail::RNNImplBase<GRUImpl> {
  public:
   GRUImpl(int64_t input_size, int64_t hidden_size)
       : GRUImpl(GRUOptions(input_size, hidden_size)) {}
   explicit GRUImpl(const GRUOptions& options_);
 
-  /// Applies the `GRU` module to an input sequence and input state.
-  /// The `input` should follow a `(sequence, batch, features)` layout unless
-  /// `batch_first` is true, in which case the layout should be `(batch,
-  /// sequence, features)`.
-  RNNOutput forward(const Tensor& input, Tensor state = {});
+  std::tuple<Tensor, Tensor> forward(const Tensor& input, Tensor hx = {});
  protected:
-  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())})
+  FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(torch::Tensor())})
+
+ public:
+  std::tuple<PackedSequence, Tensor> forward_with_packed_input(const PackedSequence& packed_input, Tensor hx = {});
+
+  GRUOptions options;
+
+ protected:
+  std::tuple<Tensor, Tensor> forward_helper(
+    const Tensor& input,
+    const Tensor& batch_sizes,
+    const Tensor& sorted_indices,
+    int64_t max_batch_size,
+    Tensor hx);
 };
 
 /// A `ModuleHolder` subclass for `GRUImpl`.
-/// See the documentation for `GRUImpl` class to learn what methods it provides,
-/// or the documentation for `ModuleHolder` to learn about PyTorch's module
-/// storage semantics.
+/// See the documentation for `GRUImpl` class to learn what methods it
+/// provides, and examples of how to use `GRU` with `torch::nn::GRUOptions`.
+/// See the documentation for `ModuleHolder` to learn about PyTorch's
+/// module storage semantics.
 TORCH_MODULE(GRU);
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch/csrc/api/include/torch/nn/options/rnn.h b/torch/csrc/api/include/torch/nn/options/rnn.h
index f500a90..ae37693 100644
--- a/torch/csrc/api/include/torch/nn/options/rnn.h
+++ b/torch/csrc/api/include/torch/nn/options/rnn.h
@@ -10,65 +10,137 @@
 
 namespace detail {
 
-/// Common options for LSTM and GRU modules.
+/// Common options for RNN, LSTM and GRU modules.
 struct TORCH_API RNNOptionsBase {
-  RNNOptionsBase(int64_t input_size, int64_t hidden_size);
-  virtual ~RNNOptionsBase() = default;
+  typedef c10::variant<
+    enumtype::kLSTM,
+    enumtype::kGRU,
+    enumtype::kRNN_TANH,
+    enumtype::kRNN_RELU> rnn_options_base_mode_t;
+
+  RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size);
+
+  TORCH_ARG(rnn_options_base_mode_t, mode);
   /// The number of features of a single sample in the input sequence `x`.
   TORCH_ARG(int64_t, input_size);
   /// The number of features in the hidden state `h`.
   TORCH_ARG(int64_t, hidden_size);
   /// The number of recurrent layers (cells) to use.
-  TORCH_ARG(int64_t, layers) = 1;
+  TORCH_ARG(int64_t, num_layers) = 1;
   /// Whether a bias term should be added to all linear operations.
-  TORCH_ARG(bool, with_bias) = true;
+  TORCH_ARG(bool, bias) = true;
+  /// If true, the input sequence should be provided as `(batch, sequence,
+  /// features)`. If false (default), the expected layout is `(sequence, batch,
+  /// features)`.
+  TORCH_ARG(bool, batch_first) = false;
   /// If non-zero, adds dropout with the given probability to the output of each
   /// RNN layer, except the final layer.
   TORCH_ARG(double, dropout) = 0.0;
   /// Whether to make the RNN bidirectional.
   TORCH_ARG(bool, bidirectional) = false;
-  /// If true, the input sequence should be provided as `(batch, sequence,
-  /// features)`. If false (default), the expected layout is `(sequence, batch,
-  /// features)`.
-  TORCH_ARG(bool, batch_first) = false;
 };
 
 } // namespace detail
 
-enum class RNNActivation : uint32_t {ReLU, Tanh};
-
-/// Options for RNN modules.
+/// Options for the `RNN` module.
+///
+/// Example:
+/// ```
+/// RNN model(RNNOptions(128, 64).num_layers(3).dropout(0.2).nonlinearity(torch::kTanh));
+/// ```
 struct TORCH_API RNNOptions {
+  typedef c10::variant<enumtype::kTanh, enumtype::kReLU> nonlinearity_t;
+
   RNNOptions(int64_t input_size, int64_t hidden_size);
 
-  /// Sets the activation after linear operations to `tanh`.
-  RNNOptions& tanh();
-  /// Sets the activation after linear operations to `relu`.
-  RNNOptions& relu();
-
-  /// The number of features of a single sample in the input sequence `x`.
+  /// The number of expected features in the input `x`
   TORCH_ARG(int64_t, input_size);
-  /// The number of features in the hidden state `h`.
+  /// The number of features in the hidden state `h`
   TORCH_ARG(int64_t, hidden_size);
-  /// The number of recurrent layers (cells) to use.
-  TORCH_ARG(int64_t, layers) = 1;
-  /// Whether a bias term should be added to all linear operations.
-  TORCH_ARG(bool, with_bias) = true;
-  /// If non-zero, adds dropout with the given probability to the output of each
-  /// RNN layer, except the final layer.
-  TORCH_ARG(double, dropout) = 0.0;
-  /// Whether to make the RNN bidirectional.
-  TORCH_ARG(bool, bidirectional) = false;
-  /// If true, the input sequence should be provided as `(batch, sequence,
-  /// features)`. If false (default), the expected layout is `(sequence, batch,
-  /// features)`.
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two RNNs together to form a `stacked RNN`,
+  /// with the second RNN taking in outputs of the first RNN and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// The non-linearity to use. Can be either ``torch::kTanh`` or ``torch::kReLU``. Default: ``torch::kTanh``
+  TORCH_ARG(nonlinearity_t, nonlinearity) = torch::kTanh;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as `(batch, seq, feature)`. Default: ``false``
   TORCH_ARG(bool, batch_first) = false;
-  /// The activation to use after linear operations.
-  TORCH_ARG(RNNActivation, activation) = RNNActivation::ReLU;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// RNN layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional RNN. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
 };
 
-using LSTMOptions = detail::RNNOptionsBase;
-using GRUOptions = detail::RNNOptionsBase;
+/// Options for the `LSTM` module.
+///
+/// Example:
+/// ```
+/// LSTM model(LSTMOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API LSTMOptions {
+  LSTMOptions(int64_t input_size, int64_t hidden_size);
+
+  /// The number of expected features in the input `x`
+  TORCH_ARG(int64_t, input_size);
+  /// The number of features in the hidden state `h`
+  TORCH_ARG(int64_t, hidden_size);
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two LSTMs together to form a `stacked LSTM`,
+  /// with the second LSTM taking in outputs of the first LSTM and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as (batch, seq, feature). Default: ``false``
+  TORCH_ARG(bool, batch_first) = false;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// LSTM layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional LSTM. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
+};
+
+/// Options for the `GRU` module.
+///
+/// Example:
+/// ```
+/// GRU model(GRUOptions(2, 4).num_layers(3).batch_first(false).bidirectional(true));
+/// ```
+struct TORCH_API GRUOptions {
+  GRUOptions(int64_t input_size, int64_t hidden_size);
+
+  /// The number of expected features in the input `x`
+  TORCH_ARG(int64_t, input_size);
+  /// The number of features in the hidden state `h`
+  TORCH_ARG(int64_t, hidden_size);
+  /// Number of recurrent layers. E.g., setting ``num_layers=2``
+  /// would mean stacking two GRUs together to form a `stacked GRU`,
+  /// with the second GRU taking in outputs of the first GRU and
+  /// computing the final results. Default: 1
+  TORCH_ARG(int64_t, num_layers) = 1;
+  /// If ``false``, then the layer does not use bias weights `b_ih` and `b_hh`.
+  /// Default: ``true``
+  TORCH_ARG(bool, bias) = true;
+  /// If ``true``, then the input and output tensors are provided
+  /// as (batch, seq, feature). Default: ``false``
+  TORCH_ARG(bool, batch_first) = false;
+  /// If non-zero, introduces a `Dropout` layer on the outputs of each
+  /// GRU layer except the last layer, with dropout probability equal to
+  /// `dropout`. Default: 0
+  TORCH_ARG(double, dropout) = 0.0;
+  /// If ``true``, becomes a bidirectional GRU. Default: ``false``
+  TORCH_ARG(bool, bidirectional) = false;
+};
 
 namespace detail {
 
diff --git a/torch/csrc/api/src/enum.cpp b/torch/csrc/api/src/enum.cpp
index de15449..28bd25b 100644
--- a/torch/csrc/api/src/enum.cpp
+++ b/torch/csrc/api/src/enum.cpp
@@ -30,3 +30,7 @@
 TORCH_ENUM_DEFINE(Zeros)
 TORCH_ENUM_DEFINE(Border)
 TORCH_ENUM_DEFINE(Reflection)
+TORCH_ENUM_DEFINE(RNN_TANH)
+TORCH_ENUM_DEFINE(RNN_RELU)
+TORCH_ENUM_DEFINE(LSTM)
+TORCH_ENUM_DEFINE(GRU)
diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp
index 7f8aa39..0e54fda 100644
--- a/torch/csrc/api/src/nn/modules/rnn.cpp
+++ b/torch/csrc/api/src/nn/modules/rnn.cpp
@@ -1,6 +1,5 @@
 #include <torch/nn/modules/rnn.h>
 
-#include <torch/nn/modules/dropout.h>
 #include <torch/nn/init.h>
 #include <torch/types.h>
 #include <torch/utils.h>
@@ -19,65 +18,193 @@
 #include <utility>
 #include <vector>
 
+using namespace torch::nn::utils::rnn;
+
 namespace torch {
 namespace nn {
+
+/// These must line up with the CUDNN mode codes:
+/// https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnRNNMode_t
+enum class CuDNNMode { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
+
+CuDNNMode get_cudnn_mode_for_rnn(detail::RNNOptionsBase::rnn_options_base_mode_t mode) {
+  if (c10::get_if<enumtype::kRNN_RELU>(&mode)) {
+    return CuDNNMode::RNN_RELU;
+  } else if (c10::get_if<enumtype::kRNN_TANH>(&mode)) {
+    return CuDNNMode::RNN_TANH;
+  } else if (c10::get_if<enumtype::kLSTM>(&mode)) {
+    return CuDNNMode::LSTM;
+  } else if (c10::get_if<enumtype::kGRU>(&mode)) {
+    return CuDNNMode::GRU;
+  } else {
+    TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(mode));
+  }
+}
+
+Tensor apply_permutation(const Tensor& tensor, const Tensor& permutation, int64_t dim = 1) {
+  return tensor.index_select(dim, permutation);
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 namespace detail {
 template <typename Derived>
-RNNImplBase<Derived>::RNNImplBase(
-    const RNNOptionsBase& options_,
-    optional<CuDNNMode> cudnn_mode,
-    int64_t number_of_gates)
-    : options(options_),
-      number_of_gates_(number_of_gates),
-      cudnn_mode_(std::move(cudnn_mode)) {
+RNNImplBase<Derived>::RNNImplBase(const RNNOptionsBase& options_)
+  : options_base(options_) {
   reset();
 }
 
 template <typename Derived>
 void RNNImplBase<Derived>::reset() {
-  const auto num_directions = options.bidirectional() ? 2 : 1;
+  const int64_t num_directions = options_base.bidirectional() ? 2 : 1;
 
-  w_ih.resize(options.layers() * num_directions);
-  w_hh.resize(options.layers() * num_directions);
-  b_ih.resize(options.layers() * num_directions);
-  b_hh.resize(options.layers() * num_directions);
+  TORCH_CHECK(
+    0 <= options_base.dropout() && options_base.dropout() <= 1,
+    "dropout should be a number in range [0, 1] ",
+    "representing the probability of an element being ",
+    "zeroed");
 
-  const int64_t gate_size = options.hidden_size() * number_of_gates_;
+  if (options_base.dropout() > 0 && options_base.num_layers() == 1) {
+    TORCH_WARN(
+      "dropout option adds dropout after all but last ",
+      "recurrent layer, so non-zero dropout expects ",
+      "num_layers greater than 1, but got dropout=", options_base.dropout(), " and ",
+      "num_layers=", options_base.num_layers());
+  }
 
-  for (int64_t layer = 0; layer < options.layers(); ++layer) {
-    for (auto direction = 0; direction < num_directions; direction++) {
-      const auto layer_input_size = layer == 0 ? options.input_size() :
-        options.hidden_size() * num_directions;
-      const auto suffix = direction == 1 ? "_reverse" : "";
-      const auto layer_idx = (layer * num_directions) + direction;
-      w_ih[layer_idx] = this->register_parameter(
-          "weight_ih_l" + std::to_string(layer) + suffix,
-          torch::empty({gate_size, layer_input_size}));
-      w_hh[layer_idx] = this->register_parameter(
-          "weight_hh_l" + std::to_string(layer) + suffix,
-          torch::empty({gate_size, options.hidden_size()}));
+  int64_t gate_size = 0;
+  if (c10::get_if<enumtype::kLSTM>(&options_base.mode())) {
+    gate_size = 4 * options_base.hidden_size();
+  } else if (c10::get_if<enumtype::kGRU>(&options_base.mode())) {
+    gate_size = 3 * options_base.hidden_size();
+  } else if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+    gate_size = options_base.hidden_size();
+  } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+    gate_size = options_base.hidden_size();
+  } else {
+    TORCH_CHECK(false, "Unrecognized RNN mode: " + torch::enumtype::get_enum_name(options_base.mode()));
+  }
 
-      if (options.with_bias()) {
-        b_ih[layer_idx] = this->register_parameter(
-          "bias_ih_l" + std::to_string(layer) + suffix,
-          torch::empty({gate_size}));
-        b_hh[layer_idx] = this->register_parameter(
-          "bias_hh_l" + std::to_string(layer) + suffix,
-          torch::empty({gate_size}));
+  flat_weights_names_ = {};
+  all_weights_ = {};
+
+  for (int64_t layer = 0; layer < options_base.num_layers(); layer++) {
+    for (int64_t direction = 0; direction < num_directions; direction++) {
+      int64_t layer_input_size = layer == 0 ? options_base.input_size() : options_base.hidden_size() * num_directions;
+
+      auto w_ih = torch::empty({gate_size, layer_input_size});
+      auto w_hh = torch::empty({gate_size, options_base.hidden_size()});
+      auto b_ih = torch::empty({gate_size});
+      // Second bias vector included for CuDNN compatibility. Only one
+      // bias vector is needed in standard definition.
+      auto b_hh = torch::empty({gate_size});
+      std::vector<Tensor> layer_params = {w_ih, w_hh, b_ih, b_hh};
+
+      std::string suffix = direction == 1 ? "_reverse" : "";
+      std::vector<std::string> param_names = {"weight_ih_l{layer}{suffix}", "weight_hh_l{layer}{suffix}"};
+      if (options_base.bias()) {
+        param_names.emplace_back("bias_ih_l{layer}{suffix}");
+        param_names.emplace_back("bias_hh_l{layer}{suffix}");
       }
+      for (size_t i = 0; i < param_names.size(); i++) {  // NOLINT(modernize-loop-convert)
+        std::string x = std::regex_replace(param_names[i], std::regex("\\{layer\\}"), c10::str(layer));
+        x = std::regex_replace(x, std::regex("\\{suffix\\}"), c10::str(suffix));
+        param_names[i] = x;
+      }
+
+      for (size_t i = 0; i < param_names.size(); i++) {
+        auto name = param_names[i];
+        auto param = layer_params[i];
+        this->register_parameter(name, param);
+      }
+      flat_weights_names_.insert(flat_weights_names_.end(), param_names.begin(), param_names.end());
+      all_weights_.emplace_back(param_names);
     }
   }
 
+  flat_weights_ = {};
+  for (const auto& wn : flat_weights_names_) {
+    auto named_parameters = this->named_parameters(/*recurse=*/false);
+    if (named_parameters.contains(wn)) {
+      flat_weights_.emplace_back(named_parameters[wn]);
+    } else {
+      flat_weights_.emplace_back(Tensor());
+    }
+  }
+
+  this->flatten_parameters();
+  this->reset_parameters(); 
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::flatten_parameters() {
+  // Resets parameter data pointer so that they can use faster code paths.
+  //
+  // Right now, this works only if the module is on the GPU and cuDNN is enabled.
+  // Otherwise, it's a no-op.
+
+  // Short-circuits if flat_weights_ is only partially instantiated
+  if (flat_weights_.size() != flat_weights_names_.size()) {
+    return;
+  }
+
+  // Short-circuits if any tensor in self.flat_weights_ is not acceptable to cuDNN
+  // or the tensors in flat_weights_ are of different dtypes
+
+  auto first_fw = flat_weights_[0];
+  auto dtype = first_fw.dtype();
+  for (const auto& fw : flat_weights_) {
+    if (!(fw.dtype() == dtype) ||
+        !fw.is_cuda() ||
+        !torch::cudnn_is_acceptable(fw)) {
+      return;
+    }
+  }
+
+  // If any parameters alias, we fall back to the slower, copying code path. This is
+  // a sufficient check, because overlapping parameter buffers that don't completely
+  // alias would break the assumptions of the uniqueness check in
+  // Module::named_parameters().
+  std::unordered_set<void*> unique_data_ptrs;
+  for (const auto& p : flat_weights_) {
+    unique_data_ptrs.emplace(p.data_ptr());
+  }
+  if (unique_data_ptrs.size() != flat_weights_.size()) {
+    return;
+  }
+
   {
-    NoGradGuard no_grad;
-    const auto stdv = 1.0 / std::sqrt(options.hidden_size());
-    for (auto& p : this->parameters()) {
-      p.uniform_(-stdv, stdv);
+    torch::DeviceGuard device_guard(first_fw.device());
+
+    // Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
+    // an inplace operation on self.flat_weights_
+    {
+      torch::NoGradGuard no_grad;
+      if (torch::_use_cudnn_rnn_flatten_weight()) {
+        torch::_cudnn_rnn_flatten_weight(
+              flat_weights_,
+              options_base.bias() ? 4 : 2,
+              options_base.input_size(),
+              static_cast<int64_t>(get_cudnn_mode_for_rnn(options_base.mode())),
+              options_base.hidden_size(),
+              options_base.num_layers(),
+              options_base.batch_first(), 
+              options_base.bidirectional());
+      }
     }
   }
+}
 
-  flatten_parameters();
+template <typename Derived>
+void RNNImplBase<Derived>::reset_flat_weights() {
+  flat_weights_ = {};
+  for (const auto& wn : flat_weights_names_) {
+    auto named_parameters = this->named_parameters(/*recurse=*/false);
+    if (named_parameters.contains(wn)) {
+      flat_weights_.emplace_back(named_parameters[wn]);
+    } else {
+      flat_weights_.emplace_back(Tensor());
+    }
+  }
 }
 
 template <typename Derived>
@@ -86,126 +213,113 @@
     torch::Dtype dtype,
     bool non_blocking) {
   nn::Module::to(device, dtype, non_blocking);
+  reset_flat_weights();
   flatten_parameters();
 }
 
 template <typename Derived>
 void RNNImplBase<Derived>::to(torch::Dtype dtype, bool non_blocking) {
   nn::Module::to(dtype, non_blocking);
+  reset_flat_weights();
   flatten_parameters();
 }
 
 template <typename Derived>
 void RNNImplBase<Derived>::to(torch::Device device, bool non_blocking) {
   nn::Module::to(device, non_blocking);
-  const auto num_directions = options.bidirectional() ? 2 : 1;
-  for (int64_t layer = 0; layer < options.layers(); layer++) {
-    for (auto direction = 0; direction < num_directions; direction++) {
-      const auto layer_idx = (layer * num_directions) + direction;
-      w_ih[layer_idx] = w_ih[layer_idx].to(device, non_blocking);
-      w_hh[layer_idx] = w_hh[layer_idx].to(device, non_blocking);
-      if (options.with_bias()) {
-        b_ih[layer_idx] = b_ih[layer_idx].to(device, non_blocking);
-        b_hh[layer_idx] = b_hh[layer_idx].to(device, non_blocking);
-      }
-    }
-  }
+  reset_flat_weights();
   flatten_parameters();
 }
 
 template <typename Derived>
+void RNNImplBase<Derived>::reset_parameters() {
+  const double stdv = 1.0 / std::sqrt(options_base.hidden_size());
+  for (auto& weight : this->parameters()) {
+    init::uniform_(weight, -stdv, stdv);
+  }
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_input(const Tensor& input, const Tensor& batch_sizes) const {
+  int64_t expected_input_dim = batch_sizes.defined() ?  2 : 3;
+  TORCH_CHECK(
+    input.dim() == expected_input_dim,
+    "input must have ", expected_input_dim, " dimensions, got ", input.dim());
+  TORCH_CHECK(
+    options_base.input_size() == input.size(-1),
+    "input.size(-1) must be equal to input_size. Expected ", options_base.input_size(), ", got ", input.size(-1));
+}
+
+template <typename Derived>
+std::tuple<int64_t, int64_t, int64_t> RNNImplBase<Derived>::get_expected_hidden_size(
+  const Tensor& input, const Tensor& batch_sizes) const {
+  int64_t mini_batch = 0;
+  if (batch_sizes.defined()) {
+    mini_batch = batch_sizes[0].item<int64_t>();
+  } else {
+    mini_batch = options_base.batch_first() ? input.size(0) : input.size(1);
+  }    
+  int64_t num_directions = options_base.bidirectional() ? 2 : 1;
+  return std::make_tuple(options_base.num_layers() * num_directions, mini_batch, options_base.hidden_size());
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_hidden_size(
+    const Tensor& hx,
+    std::tuple<int64_t, int64_t, int64_t> expected_hidden_size,
+    std::string msg) const {
+  auto expected_hidden_size_vec = std::vector<int64_t>({
+    std::get<0>(expected_hidden_size),
+    std::get<1>(expected_hidden_size),
+    std::get<2>(expected_hidden_size),
+  });
+  if (hx.sizes() != expected_hidden_size_vec) {
+    msg = std::regex_replace(msg, std::regex("\\{1\\}"), c10::str(expected_hidden_size_vec));
+    msg = std::regex_replace(msg, std::regex("\\{2\\}"), c10::str(hx.sizes()));
+    TORCH_CHECK(false, msg);
+  }
+}
+
+template <typename Derived>
+void RNNImplBase<Derived>::check_forward_args(Tensor input, Tensor hidden, Tensor batch_sizes) const {
+  this->check_input(input, batch_sizes);
+  auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
+
+  this->check_hidden_size(hidden, expected_hidden_size);
+}
+
+template <typename Derived>
+Tensor RNNImplBase<Derived>::permute_hidden(Tensor hx, const Tensor& permutation) const {
+  if (!permutation.defined()) {
+    return hx;
+  }
+  return apply_permutation(hx, permutation);
+}
+
+template <typename Derived>
 void RNNImplBase<Derived>::pretty_print(std::ostream& stream) const {
   const std::string name = this->name();
   const std::string name_without_impl = name.substr(0, name.size() - 4);
-  stream << name_without_impl << "(input_size=" << options.input_size()
-         << ", hidden_size=" << options.hidden_size()
-         << ", layers=" << options.layers() << ", dropout=" << options.dropout()
+  stream << std::boolalpha << name_without_impl << "(input_size=" << options_base.input_size()
+         << ", hidden_size=" << options_base.hidden_size()
+         << ", num_layers=" << options_base.num_layers()
+         << ", bias=" << options_base.bias()
+         << ", batch_first=" << options_base.batch_first()
+         << ", dropout=" << options_base.dropout()
+         << ", bidirectional=" << options_base.bidirectional()
          << ")";
 }
 
 template <typename Derived>
-void RNNImplBase<Derived>::flatten_parameters() {
-  // Cache the flattened weight and bias vector.
-  flat_weights_ = flat_weights();
-
-  if (!cudnn_mode_ || !torch::cudnn_is_acceptable(w_ih.at(0))) {
-    return;
-  }
-
-  NoGradGuard no_grad;
-  if (torch::_use_cudnn_rnn_flatten_weight()) {
-    torch::_cudnn_rnn_flatten_weight(
-        flat_weights_,
-        /*weight_stride0=*/options.with_bias() ? 4 : 2,
-        options.input_size(),
-        static_cast<int64_t>(*cudnn_mode_),
-        options.hidden_size(),
-        options.layers(),
-        /*batch_first=*/options.batch_first(),
-        /*bidirectional=*/options.bidirectional());
-  }
-}
-
-template <typename Derived>
-RNNOutput RNNImplBase<Derived>::generic_forward(
-    std::function<RNNFunctionSignature> function,
-    const Tensor& input,
-    Tensor state) {
-  if (!state.defined()) {
-    // #layers, batch size, state size
-    const auto batch_size = input.size(options.batch_first() ? 0 : 1);
-    const auto num_directions = options.bidirectional() ? 2 : 1;
-    state = torch::zeros(
-      {options.layers() * num_directions, batch_size, options.hidden_size()},
-      input.options());
-  }
-  Tensor output, new_state;
-  std::tie(output, new_state) = function(
-      input,
-      std::move(state),
-      flat_weights_,
-      options.with_bias(),
-      options.layers(),
-      options.dropout(),
-      this->is_training(),
-      options.bidirectional(),
-      options.batch_first());
-  return {output, new_state};
-}
-
-template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::flat_weights() const {
-  // Organize all weights in a flat vector in the order
-  // (w_ih, w_hh, b_ih, b_hh), repeated for each layer (next to each other).
-  std::vector<Tensor> flat;
-  const auto num_directions = options.bidirectional() ? 2 : 1;
-  for (int64_t layer = 0; layer < options.layers(); layer++) {
-    for (auto direction = 0; direction < num_directions; direction++) {
-      const auto layer_idx = (layer * num_directions) + direction;
-      flat.push_back(w_ih[layer_idx]);
-      flat.push_back(w_hh[layer_idx]);
-      if (options.with_bias()) {
-        flat.push_back(b_ih[layer_idx]);
-        flat.push_back(b_hh[layer_idx]);
-      }
+std::vector<Tensor> RNNImplBase<Derived>::all_weights() const {
+  std::vector<Tensor> result = {};
+  auto named_parameters = this->named_parameters(/*recurse=*/false);
+  for (const auto& weights : all_weights_) {
+    for (const auto& weight : weights) {
+      result.emplace_back(named_parameters[weight]);
     }
   }
-  return flat;
-}
-
-template <typename Derived>
-bool RNNImplBase<Derived>::any_parameters_alias() const {
-  // If any parameters alias, we fall back to the slower, copying code path.
-  // This is a sufficient check, because overlapping parameter buffers that
-  // don't completely alias would break the assumptions of the uniqueness check
-  // in Module.named_parameters().
-  std::unordered_set<void*> unique_data_ptrs;
-  auto params = this->parameters();
-  unique_data_ptrs.reserve(params.size());
-  for (const auto& p : params) {
-    unique_data_ptrs.emplace(p.data_ptr());
-  }
-  return unique_data_ptrs.size() != params.size();
+  return result;
 }
 
 template class RNNImplBase<LSTMImpl>;
@@ -215,91 +329,275 @@
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-RNNImpl::RNNImpl(const RNNOptions& options_)
-    : detail::RNNImplBase<RNNImpl>(
-          detail::RNNOptionsBase(options_.input_size(), options_.hidden_size())
-              .layers(options_.layers())
-              .with_bias(options_.with_bias())
-              .dropout(options_.dropout())
-              .bidirectional(options_.bidirectional())
-              .batch_first(options_.batch_first()),
-          static_cast<CuDNNMode>(options_.activation())),
-      options(options_) {}
-
-void RNNImpl::pretty_print(std::ostream& stream) const {
-  stream << "torch::nn::RNN(input_size=" << options.input_size()
-         << ", hidden_size=" << options.hidden_size()
-         << ", layers=" << options.layers() << ", dropout=" << options.dropout()
-         << ", activation="
-         << (options.activation() == RNNActivation::Tanh ? "tanh" : "relu")
-         << ")";
+detail::RNNOptionsBase::rnn_options_base_mode_t compute_rnn_options_base_mode(
+  RNNOptions::nonlinearity_t nonlinearity) {
+  if (c10::get_if<enumtype::kTanh>(&nonlinearity)) {
+    return torch::kRNN_TANH;
+  } else if (c10::get_if<enumtype::kReLU>(&nonlinearity)) {
+    return torch::kRNN_RELU;
+  } else {
+    TORCH_CHECK(false, "Unknown nonlinearity ", torch::enumtype::get_enum_name(nonlinearity));
+  }
 }
 
-RNNOutput RNNImpl::forward(const Tensor& input, Tensor state) {
-  switch (options.activation()) {
-    case RNNActivation::ReLU:
-      return generic_forward(
-          static_cast<RNNFunctionSignature*>(&torch::rnn_relu),
-          input,
-          std::move(state));
-    case RNNActivation::Tanh:
-      return generic_forward(
-          static_cast<RNNFunctionSignature*>(&torch::rnn_tanh),
-          input,
-          std::move(state));
-    default:
-      AT_ERROR("Unhandled RNN activation function!");
+RNNImpl::RNNImpl(const RNNOptions& options_)
+    : detail::RNNImplBase<RNNImpl>(
+          detail::RNNOptionsBase(
+            compute_rnn_options_base_mode(options_.nonlinearity()),
+            options_.input_size(),
+            options_.hidden_size())
+              .num_layers(options_.num_layers())
+              .bias(options_.bias())
+              .batch_first(options_.batch_first())
+              .dropout(options_.dropout())
+              .bidirectional(options_.bidirectional())),
+      options(options_) {}
+
+std::tuple<Tensor, Tensor> RNNImpl::forward_helper(
+  const Tensor& input,
+  const Tensor& batch_sizes,
+  const Tensor& sorted_indices,
+  int64_t max_batch_size,
+  Tensor hx) {
+  if (!hx.defined()) {
+    int64_t num_directions = options_base.bidirectional() ? 2 : 1;
+    hx = torch::zeros({options_base.num_layers() * num_directions,
+                     max_batch_size, options_base.hidden_size()},
+                     torch::dtype(input.dtype()).device(input.device()));
+  } else {
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes he/she is passing in.
+    hx = this->permute_hidden(hx, sorted_indices);
+  }    
+
+  this->check_forward_args(input, hx, batch_sizes);
+
+  std::tuple<Tensor, Tensor> result;
+  if (!batch_sizes.defined()) {
+    if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+      result = torch::rnn_tanh(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
+                               options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
+    } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+      result = torch::rnn_relu(input, hx, flat_weights_, options_base.bias(), options_base.num_layers(),
+                               options_base.dropout(), this->is_training(), options_base.bidirectional(), options_base.batch_first());
+    } else {
+      TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
+    }
+  } else {
+    if (c10::get_if<enumtype::kRNN_TANH>(&options_base.mode())) {
+      result = torch::rnn_tanh(input, batch_sizes, hx, flat_weights_, options_base.bias(),
+                               options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
+    } else if (c10::get_if<enumtype::kRNN_RELU>(&options_base.mode())) {
+      result = torch::rnn_relu(input, batch_sizes, hx, flat_weights_, options_base.bias(),
+                               options_base.num_layers(), options_base.dropout(), this->is_training(), options_base.bidirectional());
+    } else {
+      TORCH_CHECK(false, "Unknown mode: ", torch::enumtype::get_enum_name(options_base.mode()));
+    }
   }
+  auto output = std::get<0>(result);
+  auto hidden = std::get<1>(result);
+
+  return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, Tensor> RNNImpl::forward(const Tensor& input, Tensor hx) {
+  auto batch_sizes = torch::Tensor();
+  auto max_batch_size = options_base.batch_first() ? input.size(0) : input.size(1);
+  auto sorted_indices = torch::Tensor();
+  auto unsorted_indices = torch::Tensor();
+
+  Tensor output, hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+  return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, Tensor> RNNImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
+  const auto& input = packed_input.data();
+  const auto& batch_sizes = packed_input.batch_sizes();
+  const auto& sorted_indices = packed_input.sorted_indices();
+  const auto& unsorted_indices = packed_input.unsorted_indices();
+  auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+  Tensor output, hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+  auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+  return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LSTM ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 LSTMImpl::LSTMImpl(const LSTMOptions& options_)
     : detail::RNNImplBase<LSTMImpl>(
-          options_,
-          CuDNNMode::LSTM,
-          /*number_of_gates=*/4) {}
+          detail::RNNOptionsBase(
+            torch::kLSTM,
+            options_.input_size(),
+            options_.hidden_size())
+              .num_layers(options_.num_layers())
+              .bias(options_.bias())
+              .batch_first(options_.batch_first())
+              .dropout(options_.dropout())
+              .bidirectional(options_.bidirectional())),
+      options(options_) {}
 
-RNNOutput LSTMImpl::forward(const Tensor& input, Tensor state) {
-  // It would be trickier to adapt the `generic_forward` for the LSTM because
-  // its output has a different dimensionality (3-tuple vs. 2-tuple), while we
-  // always return one state variable (stacking the hidden/cell state into one),
-  // which also makes the state variables going into the `generic_forward`, and
-  // the way we default-initialize the state when it is not passed, slightly
-  // different. So we just re-implement it specifically for the LSTM here.
-  if (!state.defined()) {
-    // 2 for hidden state and cell state, then #layers, batch size, state size
-    const auto batch_size = input.size(options.batch_first() ? 0 : 1);
-    const auto num_directions = options.bidirectional() ? 2 : 1;
-    state = torch::zeros(
-        {2, options.layers() * num_directions, batch_size, options.hidden_size()},
-        input.options());
+void LSTMImpl::check_forward_args(const Tensor& input, std::tuple<Tensor, Tensor> hidden, const Tensor& batch_sizes) const {
+  this->check_input(input, batch_sizes);
+  auto expected_hidden_size = this->get_expected_hidden_size(input, batch_sizes);
+
+  this->check_hidden_size(std::get<0>(hidden), expected_hidden_size,
+                          "Expected hidden[0] size {1}, got {2}");
+  this->check_hidden_size(std::get<1>(hidden), expected_hidden_size,
+                          "Expected hidden[1] size {1}, got {2}");
+}
+
+std::tuple<Tensor, Tensor> LSTMImpl::permute_hidden(std::tuple<Tensor, Tensor> hx, const Tensor& permutation) const {
+  if (!permutation.defined()) {
+    return hx;
   }
-  Tensor output, hidden_state, cell_state;
-  std::tie(output, hidden_state, cell_state) = torch::lstm(
-      input,
-      {state[0], state[1]},
-      flat_weights_,
-      options.with_bias(),
-      options.layers(),
-      options.dropout(),
-      this->is_training(),
-      options.bidirectional(),
-      options.batch_first());
-  return {output, torch::stack({hidden_state, cell_state})};
+  return std::make_tuple(
+    apply_permutation(std::get<0>(hx), permutation),
+    apply_permutation(std::get<1>(hx), permutation)
+  );
+}
+
+std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward_helper(
+  const Tensor& input,
+  const Tensor& batch_sizes,
+  const Tensor& sorted_indices,
+  int64_t max_batch_size,
+  torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+
+  std::tuple<Tensor, Tensor> hx;
+  if (!hx_opt.has_value()) {
+    int64_t num_directions = options.bidirectional() ? 2 : 1;
+    auto zeros = torch::zeros({options.num_layers() * num_directions,
+                     max_batch_size, options.hidden_size()},
+                     torch::dtype(input.dtype()).device(input.device()));
+    hx = std::make_tuple(zeros, zeros);
+  } else {
+    hx = hx_opt.value();
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes he/she is passing in.
+    hx = this->permute_hidden(hx, sorted_indices);
+  }    
+
+  this->check_forward_args(input, hx, batch_sizes);
+  std::tuple<Tensor, Tensor, Tensor> result;
+  if (!batch_sizes.defined()) {
+    result = torch::lstm(input, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(), options.num_layers(),
+                     options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
+  } else {
+    result = torch::lstm(input, batch_sizes, {std::get<0>(hx), std::get<1>(hx)}, flat_weights_, options.bias(),
+                     options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
+  }
+  auto output = std::get<0>(result);
+  auto hidden = std::make_tuple(std::get<1>(result), std::get<2>(result));
+
+  return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, std::tuple<Tensor, Tensor>> LSTMImpl::forward(
+  const Tensor& input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+  auto batch_sizes = torch::Tensor();
+  auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
+  auto sorted_indices = torch::Tensor();
+  auto unsorted_indices = torch::Tensor();
+
+  Tensor output;
+  std::tuple<Tensor, Tensor> hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
+
+  return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, std::tuple<Tensor, Tensor>> LSTMImpl::forward_with_packed_input(
+  const PackedSequence& packed_input, torch::optional<std::tuple<Tensor, Tensor>> hx_opt) {
+  const auto& input = packed_input.data();
+  const auto& batch_sizes = packed_input.batch_sizes();
+  const auto& sorted_indices = packed_input.sorted_indices();
+  const auto& unsorted_indices = packed_input.unsorted_indices();
+  auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+  Tensor output;
+  std::tuple<Tensor, Tensor> hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx_opt));
+
+  auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+  return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GRU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 GRUImpl::GRUImpl(const GRUOptions& options_)
     : detail::RNNImplBase<GRUImpl>(
-          options_,
-          CuDNNMode::GRU,
-          /*number_of_gates=*/3) {}
+          detail::RNNOptionsBase(
+            torch::kGRU,
+            options_.input_size(),
+            options_.hidden_size())
+              .num_layers(options_.num_layers())
+              .bias(options_.bias())
+              .batch_first(options_.batch_first())
+              .dropout(options_.dropout())
+              .bidirectional(options_.bidirectional())),
+      options(options_) {}
 
-RNNOutput GRUImpl::forward(const Tensor& input, Tensor state) {
-  return generic_forward(
-      static_cast<RNNFunctionSignature*>(&torch::gru), input, std::move(state));
+std::tuple<Tensor, Tensor> GRUImpl::forward_helper(
+  const Tensor& input,
+  const Tensor& batch_sizes,
+  const Tensor& sorted_indices,
+  int64_t max_batch_size,
+  Tensor hx) {
+  if (!hx.defined()) {
+    int64_t num_directions = options.bidirectional() ? 2 : 1;
+    hx = torch::zeros({options.num_layers() * num_directions,
+                     max_batch_size, options.hidden_size()},
+                     torch::dtype(input.dtype()).device(input.device()));
+  } else {
+    // Each batch of the hidden state should match the input sequence that
+    // the user believes he/she is passing in.
+    hx = this->permute_hidden(hx, sorted_indices);
+  }    
+
+  this->check_forward_args(input, hx, batch_sizes);
+  std::tuple<Tensor, Tensor> result;
+  if (!batch_sizes.defined()) {
+    result = torch::gru(input, hx, flat_weights_, options.bias(), options.num_layers(),
+                     options.dropout(), this->is_training(), options.bidirectional(), options.batch_first());
+  } else {
+    result = torch::gru(input, batch_sizes, hx, flat_weights_, options.bias(),
+                     options.num_layers(), options.dropout(), this->is_training(), options.bidirectional());
+  }
+  auto output = std::get<0>(result);
+  auto hidden = std::get<1>(result);
+
+  return std::make_tuple(output, hidden);
+}
+
+std::tuple<Tensor, Tensor> GRUImpl::forward(const Tensor& input, Tensor hx) {
+  auto batch_sizes = torch::Tensor();
+  auto max_batch_size = options.batch_first() ? input.size(0) : input.size(1);
+  auto sorted_indices = torch::Tensor();
+  auto unsorted_indices = torch::Tensor();
+
+  Tensor output, hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+  return std::make_tuple(output, this->permute_hidden(hidden, unsorted_indices));
+}
+
+std::tuple<PackedSequence, Tensor> GRUImpl::forward_with_packed_input(const PackedSequence& packed_input, Tensor hx) {
+  const auto& input = packed_input.data();
+  const auto& batch_sizes = packed_input.batch_sizes();
+  const auto& sorted_indices = packed_input.sorted_indices();
+  const auto& unsorted_indices = packed_input.unsorted_indices();
+  auto max_batch_size = batch_sizes[0].item<int64_t>();
+
+  Tensor output, hidden;
+  std::tie(output, hidden) = this->forward_helper(input, batch_sizes, sorted_indices, max_batch_size, std::move(hx));
+
+  auto output_packed = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices);
+  return std::make_tuple(output_packed, this->permute_hidden(hidden, unsorted_indices));
 }
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RNNCellImplBase ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/torch/csrc/api/src/nn/options/rnn.cpp b/torch/csrc/api/src/nn/options/rnn.cpp
index 33b6d53..d4a0841 100644
--- a/torch/csrc/api/src/nn/options/rnn.cpp
+++ b/torch/csrc/api/src/nn/options/rnn.cpp
@@ -5,21 +5,19 @@
 
 namespace detail {
 
-RNNOptionsBase::RNNOptionsBase(int64_t input_size, int64_t hidden_size)
-    : input_size_(input_size), hidden_size_(hidden_size) {}
+RNNOptionsBase::RNNOptionsBase(rnn_options_base_mode_t mode, int64_t input_size, int64_t hidden_size)
+    : mode_(mode), input_size_(input_size), hidden_size_(hidden_size) {}
 
 } // namespace detail
 
 RNNOptions::RNNOptions(int64_t input_size, int64_t hidden_size)
     : input_size_(input_size), hidden_size_(hidden_size) {}
 
-RNNOptions& RNNOptions::tanh() {
-  return activation(RNNActivation::Tanh);
-}
+LSTMOptions::LSTMOptions(int64_t input_size, int64_t hidden_size)
+    : input_size_(input_size), hidden_size_(hidden_size) {}
 
-RNNOptions& RNNOptions::relu() {
-  return activation(RNNActivation::ReLU);
-}
+GRUOptions::GRUOptions(int64_t input_size, int64_t hidden_size)
+    : input_size_(input_size), hidden_size_(hidden_size) {}
 
 namespace detail {