[C++ API] Better forward methods (#8739)

* Better forward methods in C++ API

Capitalize the "dimension out of range" error message and update the regex in test_torch.test_flatten

Add support for operator() on Functional and on module holders

* Add operator() to Functional

* Remove the SigmoidLinear test helper in favor of Functional(at::sigmoid)

* Add BoundFunction to FunctionalImpl

* Remove the CONV_D macro from conv because it makes compiler errors harder to read
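
The central change: modules now take and return a single `Tensor` instead of `std::vector<Tensor>`, and `Functional` (plus the module holders) gains `operator()`. A minimal sketch of the new call sites, based on the updated tests in this diff; anything outside the diff is an assumption:

```cpp
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/tensor.h>

void example() {
  torch::nn::Linear model(5, 2);
  auto x = torch::randn({10, 5}, at::requires_grad());

  // Before: auto y = model->forward({x})[0];
  auto y = model->forward(x);  // single Tensor in, single Tensor out

  // Functional now has operator(), so it can be invoked directly.
  torch::nn::Functional relu(at::relu);
  auto z = relu(y);
}
```
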
diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp
index 9b6d228..172201a 100644
--- a/aten/src/ATen/ExpandUtils.cpp
+++ b/aten/src/ATen/ExpandUtils.cpp
@@ -14,26 +14,28 @@
     long dimB = dimsB - 1 - offset;
     long sizeA = (dimA >= 0) ? a[dimA] : 1;
     long sizeB = (dimB >= 0) ? b[dimB] : 1;
-    if (sizeA == sizeB || sizeA == 1 || sizeB == 1) {
-      expandedSizes[i] = std::max(sizeA, sizeB);
-    } else {
-      std::ostringstream oss;
-      oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b ("
-          << sizeB << ") at non-singleton dimension " << i;
-      throw std::runtime_error(oss.str());
-    }
+
+    AT_CHECK(
+        sizeA == sizeB || sizeA == 1 || sizeB == 1,
+        "The size of tensor a (", sizeA,
+        ") must match the size of tensor b (", sizeB,
+        ") at non-singleton dimension ", i);
+
+    expandedSizes[i] = std::max(sizeA, sizeB);
   }
 
   return expandedSizes;
 }
 
-std::tuple<std::vector<int64_t>, std::vector<int64_t> >
-inferExpandGeometry(const Tensor &tensor, IntList sizes) {
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
+    const Tensor& tensor,
+    IntList sizes) {
   int64_t ndim = sizes.size();
 
   if (tensor.dim() == 0) {
     std::vector<int64_t> expandedStrides(ndim, 0);
-    return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(sizes.vec(), expandedStrides);
+    return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
+        sizes.vec(), expandedStrides);
   }
   std::vector<int64_t> expandedSizes(ndim);
   std::vector<int64_t> expandedStrides(ndim);
@@ -43,34 +45,35 @@
     int64_t offset = ndim - 1 - i;
     int64_t dim = tensor.dim() - 1 - offset;
     int64_t size = (dim >= 0) ? tensor.sizes()[dim] : 1;
-    int64_t stride = (dim >= 0) ?
-        tensor.strides()[dim] : expandedSizes[i + 1] * expandedStrides[i + 1];
+    int64_t stride = (dim >= 0) ? tensor.strides()[dim]
+                                : expandedSizes[i + 1] * expandedStrides[i + 1];
     int64_t targetSize = sizes[i];
     if (targetSize == -1) {
-      if (dim < 0) {
-        std::ostringstream oss;
-        oss << "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, "
-            << "non-existing dimension " << i;
-        throw std::runtime_error(oss.str());
-      } else {
-        targetSize = size;
-      }
+      AT_CHECK(
+          dim >= 0,
+          "The expanded size of the tensor (",
+          targetSize,
+          ") isn't allowed in a leading, non-existing dimension ",
+          i);
+      targetSize = size;
     }
     if (size != targetSize) {
-      if (size == 1) {
-        size = targetSize;
-        stride = 0;
-      } else {
-        std::ostringstream oss;
-        oss << "The expanded size of the tensor (" << targetSize << ") must match the existing size (" << size 
-            << ") at non-singleton dimension " << i;
-        throw std::runtime_error(oss.str());
-      }
+      AT_CHECK(
+          size == 1,
+          "The expanded size of the tensor (",
+          targetSize,
+          ") must match the existing size (",
+          size,
+          ") at non-singleton dimension ",
+          i);
+      size = targetSize;
+      stride = 0;
     }
     expandedSizes[i] = size;
     expandedStrides[i] = stride;
   }
-  return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(expandedSizes, expandedStrides);
+  return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
+      expandedSizes, expandedStrides);
 }
 
-}
+} // namespace at
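
AT_CHECK takes the condition first and then a variadic list of message pieces that it concatenates into the error string. For reference, the right-aligned broadcasting rule that this loop implements (a size of 1 expands to match the other size) can be sketched in plain C++, independent of ATen:

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Standalone sketch (not ATen code) of NumPy-style broadcast size inference.
std::vector<int64_t> infer_broadcast_size(
    const std::vector<int64_t>& a,
    const std::vector<int64_t>& b) {
  const size_t ndim = std::max(a.size(), b.size());
  std::vector<int64_t> expanded(ndim);
  for (size_t i = 0; i < ndim; ++i) {
    const size_t offset = ndim - 1 - i;
    const int64_t sizeA = offset < a.size() ? a[a.size() - 1 - offset] : 1;
    const int64_t sizeB = offset < b.size() ? b[b.size() - 1 - offset] : 1;
    if (!(sizeA == sizeB || sizeA == 1 || sizeB == 1)) {
      throw std::runtime_error("sizes are not broadcastable");
    }
    expanded[i] = std::max(sizeA, sizeB);
  }
  return expanded;
}
// infer_broadcast_size({3, 1, 5}, {4, 5}) == {3, 4, 5}
```
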
diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h
index 953f808..688b2ab 100644
--- a/aten/src/ATen/WrapDimUtils.h
+++ b/aten/src/ATen/WrapDimUtils.h
@@ -17,12 +17,10 @@
 
   int64_t min = -dim_post_expr;
   int64_t max = dim_post_expr - 1;
-  if (dim < min || dim > max) {
-    std::ostringstream oss;
-    oss << "dimension out of range (expected to be in range of [" << min
-        << ", " << max << "], but got " << dim << ")",
-    throw std::runtime_error(oss.str());
-  }
+  AT_CHECK(
+      dim >= min && dim <= max,
+      "Dimension out of range (expected to be in range of [",
+      min, ", ", max, "], but got ", dim, ")");
   if (dim < 0) dim += dim_post_expr;
   return dim;
 }
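
The wrapping rule itself is unchanged by the AT_CHECK conversion; as a standalone sketch (not ATen code):

```cpp
#include <cstdint>
#include <stdexcept>

// Negative dims count from the end; out-of-range dims throw.
int64_t wrap_dim(int64_t dim, int64_t ndim) {
  const int64_t min = -ndim;
  const int64_t max = ndim - 1;
  if (dim < min || dim > max) {
    throw std::runtime_error("Dimension out of range");
  }
  return dim < 0 ? dim + ndim : dim;
}
// wrap_dim(-1, 4) == 3, wrap_dim(2, 4) == 2, wrap_dim(5, 4) throws.
```
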
diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp
index 9786fcb..87ed246 100644
--- a/test/cpp/api/integration.cpp
+++ b/test/cpp/api/integration.cpp
@@ -230,7 +230,7 @@
   std::cout << "Num correct: " << correct.data().sum().toCFloat() << " out of "
             << telabel.size(0) << std::endl;
   return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
-};
+}
 
 TEST_CASE("integration/cartpole") {
   std::cerr << "Training episodic policy gradient with a critic for up to 3000"
@@ -245,16 +245,16 @@
   std::vector<torch::Tensor> saved_values;
   std::vector<float> rewards;
 
-  auto forward = [&](std::vector<torch::Tensor> inp) {
-    auto x = linear->forward(inp)[0].clamp_min(0);
-    torch::Tensor actions = policyHead->forward({x})[0];
-    torch::Tensor value = valueHead->forward({x})[0];
+  auto forward = [&](torch::Tensor inp) {
+    auto x = linear->forward(inp).clamp_min(0);
+    torch::Tensor actions = policyHead->forward(x);
+    torch::Tensor value = valueHead->forward(x);
     return std::make_tuple(at::softmax(actions, -1), value);
   };
 
-  auto selectAction = [&](torch::Tensor state) {
-    // Only work on single state now, change index to gather for batch
-    auto out = forward({state});
+  auto selectAction = [&](at::Tensor state) {
+    // Only work on single state right now, change index to gather for batch
+    auto out = forward(state);
     auto probs = torch::Tensor(std::get<0>(out));
     auto value = torch::Tensor(std::get<1>(out));
     auto action = probs.data().multinomial(1)[0].toCInt();
@@ -340,16 +340,15 @@
   auto linear2 = model->add(Linear(50, 10), "linear2");
 
   auto forward = [&](torch::Tensor x) {
-    x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
-            .clamp_min(0);
-    x = conv2->forward({x})[0];
-    x = drop2d->forward({x})[0];
+    x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
+    x = conv2->forward(x);
+    x = drop2d->forward(x);
     x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
 
     x = x.view({-1, 320});
-    x = linear1->forward({x})[0].clamp_min(0);
-    x = drop->forward({x})[0];
-    x = linear2->forward({x})[0];
+    x = linear1->forward(x).clamp_min(0);
+    x = drop->forward(x);
+    x = linear2->forward(x);
     x = at::log_softmax(x, 1);
     return x;
   };
@@ -378,16 +377,15 @@
   auto linear2 = model->add(Linear(50, 10), "linear2");
 
   auto forward = [&](torch::Tensor x) {
-    x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
-            .clamp_min(0);
-    x = batchnorm2d->forward({x})[0];
-    x = conv2->forward({x})[0];
+    x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
+    x = batchnorm2d->forward(x);
+    x = conv2->forward(x);
     x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
 
     x = x.view({-1, 320});
-    x = linear1->forward({x})[0].clamp_min(0);
-    x = batchnorm1->forward({x})[0];
-    x = linear2->forward({x})[0];
+    x = linear1->forward(x).clamp_min(0);
+    x = batchnorm1->forward(x);
+    x = linear2->forward(x);
     x = at::log_softmax(x, 1);
     return x;
   };
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index de4cd03..c2d505d 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -22,7 +22,7 @@
     torch::NoGradGuard guard;
     Linear model(5, 2);
     auto x = torch::randn({10, 5}, at::requires_grad());
-    auto y = model->forward({x})[0];
+    auto y = model->forward(x);
     torch::Tensor s = y.sum();
 
     s.backward();
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index 6e0f887..2db70584 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -38,7 +38,7 @@
 TEST_CASE("module/zero-grad") {
   Linear module(3, 4);
   auto weight = torch::ones({8, 3}, at::requires_grad());
-  auto loss = module->forward({weight}).front().sum();
+  auto loss = module->forward(weight).sum();
   loss.backward();
   for (auto& parameter : module->parameters()) {
     auto grad = parameter->grad();
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 80a8cbf..64d6dc3 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -21,10 +21,6 @@
     l3 = register_module("l3", Linear(5, 100));
   }
 
-  std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
-    return input;
-  }
-
   Linear l1, l2, l3;
 };
 
@@ -36,10 +32,6 @@
     param_ = register_parameter("param", torch::empty({3, 2, 21}));
   }
 
-  std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
-    return input;
-  };
-
   torch::Tensor param_;
   Linear l1;
   std::shared_ptr<TestModel> t;
@@ -50,7 +42,7 @@
     SECTION("1d") {
       Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
       auto x = torch::randn({2, 3, 5}, at::requires_grad());
-      auto y = model->forward({x})[0];
+      auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
@@ -66,7 +58,7 @@
       SECTION("even") {
         Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
         auto x = torch::randn({2, 3, 5, 5}, at::requires_grad());
-        auto y = model->forward({x})[0];
+        auto y = model->forward(x);
         torch::Tensor s = y.sum();
 
         s.backward();
@@ -82,7 +74,7 @@
       SECTION("uneven") {
         Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
         auto x = torch::randn({2, 3, 5, 4}, at::requires_grad());
-        auto y = model->forward({x})[0];
+        auto y = model->forward(x);
         torch::Tensor s = y.sum();
 
         s.backward();
@@ -98,7 +90,7 @@
     SECTION("3d") {
       Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
       auto x = torch::randn({2, 3, 5, 5, 5}, at::requires_grad());
-      auto y = model->forward({x})[0];
+      auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
@@ -116,7 +108,7 @@
     SECTION("basic1") {
       Linear model(5, 2);
       auto x = torch::randn({10, 5}, at::requires_grad());
-      auto y = model->forward({x})[0];
+      auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
@@ -136,9 +128,9 @@
     auto l3 = model->add(Linear(5, 100), "l3");
 
     auto x = torch::randn({1000, 10}, at::requires_grad());
-    x = l1->forward({x})[0].clamp_min(0);
-    x = l2->forward({x})[0].clamp_min(0);
-    x = l3->forward({x})[0].clamp_min(0);
+    x = l1->forward(x).clamp_min(0);
+    x = l2->forward(x).clamp_min(0);
+    x = l3->forward(x).clamp_min(0);
 
     x.backward();
     REQUIRE(x.ndimension() == 2);
@@ -154,7 +146,7 @@
       // Cannot get gradients to change indices (input) - only for embedding
       // params
       auto x = torch::full({10}, dict_size - 1, torch::kInt64);
-      auto y = model->forward({x})[0];
+      auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
@@ -169,7 +161,7 @@
     SECTION("list") {
       Embedding model(6, 4);
       auto x = torch::full({2, 3}, 5, torch::kInt64);
-      auto y = model->forward({x})[0];
+      auto y = model->forward(x);
       torch::Tensor s = y.sum();
 
       s.backward();
@@ -183,7 +175,7 @@
   SECTION("dropout") {
     Dropout dropout(0.5);
     torch::Tensor x = torch::ones(100, at::requires_grad());
-    torch::Tensor y = dropout->forward({x})[0];
+    torch::Tensor y = dropout->forward(x);
 
     y.backward();
     REQUIRE(y.ndimension() == 1);
@@ -194,7 +186,7 @@
     // REQUIRE(y.sum().toCFloat() > 70); // Probably
 
     dropout->eval();
-    y = dropout->forward({x})[0];
+    y = dropout->forward(x);
     REQUIRE(y.data().sum().toCFloat() == 100);
   }
 
@@ -219,17 +211,31 @@
   }
 
   SECTION("functional") {
-    bool was_called = false;
-    // clang-format off
-    auto functional = Functional([&was_called](std::vector<torch::Tensor> input) {
-      was_called = true;
-      return input;
-    });
-    // clang-format on
-    auto output = functional->forward({torch::ones(5, at::requires_grad())});
-    REQUIRE(was_called);
-    REQUIRE(output.size() == 1);
-    REQUIRE(output.front().equal(torch::ones(5, at::requires_grad())));
+    {
+      bool was_called = false;
+      auto functional = Functional([&was_called](torch::Tensor input) {
+        was_called = true;
+        return input;
+      });
+      auto output = functional->forward(torch::ones(5, at::requires_grad()));
+      REQUIRE(was_called);
+      REQUIRE(output.equal(torch::ones(5, at::requires_grad())));
+
+      was_called = false;
+      output = functional(torch::ones(5, at::requires_grad()));
+      REQUIRE(was_called);
+      REQUIRE(output.equal(torch::ones(5, at::requires_grad())));
+    }
+    {
+      auto functional = Functional(at::relu);
+      REQUIRE(functional(torch::ones({})).data().toCFloat() == 1);
+      REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+      REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
+    }
+    {
+      auto functional = Functional(at::elu, /*alpha=*/1, /*scale=*/0);
+      REQUIRE(functional(torch::ones({})).toCFloat() == 0);
+    }
   }
 }
 
@@ -238,7 +244,7 @@
     Linear model(5, 2);
     model->cuda();
     auto x = torch::randn({10, 5}, at::device(at::kCUDA).requires_grad(true));
-    auto y = model->forward({x})[0];
+    auto y = model->forward(x);
     torch::Tensor s = y.sum();
 
     s.backward();
@@ -255,7 +261,7 @@
     model->cuda();
     model->cpu();
     auto x = torch::randn({10, 5}, at::requires_grad());
-    auto y = model->forward({x})[0];
+    auto y = model->forward(x);
     torch::Tensor s = y.sum();
 
     s.backward();
diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp
index 75ee0f3..e237493 100644
--- a/test/cpp/api/optim.cpp
+++ b/test/cpp/api/optim.cpp
@@ -1,6 +1,7 @@
 #include <catch.hpp>
 
 #include <torch/nn/module.h>
+#include <torch/nn/modules/functional.h>
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/sequential.h>
 #include <torch/optim.h>
@@ -62,18 +63,22 @@
 
   torch::manual_seed(0);
   Sequential model(
-      torch::SigmoidLinear(Linear(2, 3)), torch::SigmoidLinear(Linear(3, 1)));
+      Linear(2, 3),
+      Functional(at::sigmoid),
+      Linear(3, 1),
+      Functional(at::sigmoid));
+
   model.to(torch::kFloat64);
 
   // Use exact input values because matching random values is hard.
   auto parameters = model.parameters();
-  parameters.at("0.linear.weight").data().flatten() = at::tensor(
+  parameters.at("0.weight").data().flatten() = at::tensor(
       {-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64);
-  parameters.at("0.linear.bias").data() =
+  parameters.at("0.bias").data() =
       at::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64);
-  parameters.at("1.linear.weight").data().flatten() =
+  parameters.at("2.weight").data().flatten() =
       at::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64);
-  parameters.at("1.linear.bias").data() =
+  parameters.at("2.bias").data() =
       at::tensor({-0.0711}, torch::kFloat64);
 
   auto optimizer = OptimizerClass(parameters, options);
@@ -111,7 +116,10 @@
   std::srand(0);
   torch::manual_seed(0);
   Sequential model(
-      torch::SigmoidLinear(Linear(2, 8)), torch::SigmoidLinear(Linear(8, 1)));
+      Linear(2, 8),
+      Functional(at::sigmoid),
+      Linear(8, 1),
+      Functional(at::sigmoid));
 
   SECTION("sgd") {
     REQUIRE(test_optimizer_xor(
@@ -195,7 +203,7 @@
     REQUIRE(!parameter->grad().defined());
   }
 
-  auto output = model->forward({torch::ones({5, 2})}).front();
+  auto output = model->forward(torch::ones({5, 2}));
   auto loss = output.sum();
   loss.backward();
 
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 20a64f7..f45754a 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -22,9 +22,9 @@
     auto T = x.size(0);
     auto B = x.size(1);
     x = x.view({T * B, 1});
-    x = l1->forward({x})[0].view({T, B, nhid}).tanh_();
-    x = rnn->forward({x})[0][T - 1];
-    x = lo->forward({x})[0];
+    x = l1->forward(x).view({T, B, nhid}).tanh_();
+    x = rnn->forward(x).output[T - 1];
+    x = lo->forward(x);
     return x;
   };
 
@@ -61,26 +61,23 @@
   return true;
 };
 
-void check_lstm_sizes(std::vector<torch::Tensor> tup) {
+void check_lstm_sizes(RNNOutput output) {
   // Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
   // 10 and 16 time steps (10 x 16 x n)
 
-  auto out = tup[0];
-  auto hids = tup[1];
+  REQUIRE(output.output.ndimension() == 3);
+  REQUIRE(output.output.size(0) == 10);
+  REQUIRE(output.output.size(1) == 16);
+  REQUIRE(output.output.size(2) == 64);
 
-  REQUIRE(out.ndimension() == 3);
-  REQUIRE(out.size(0) == 10);
-  REQUIRE(out.size(1) == 16);
-  REQUIRE(out.size(2) == 64);
-
-  REQUIRE(hids.ndimension() == 4);
-  REQUIRE(hids.size(0) == 2); // (hx, cx)
-  REQUIRE(hids.size(1) == 3); // layers
-  REQUIRE(hids.size(2) == 16); // Batchsize
-  REQUIRE(hids.size(3) == 64); // 64 hidden dims
+  REQUIRE(output.state.ndimension() == 4);
+  REQUIRE(output.state.size(0) == 2); // (hx, cx)
+  REQUIRE(output.state.size(1) == 3); // layers
+  REQUIRE(output.state.size(2) == 16); // Batchsize
+  REQUIRE(output.state.size(3) == 64); // 64 hidden dims
 
   // Something is in the hiddens
-  REQUIRE(hids.norm().toCFloat() > 0);
+  REQUIRE(output.state.norm().toCFloat() > 0);
 }
 
 TEST_CASE("rnn") {
@@ -88,17 +85,17 @@
     SECTION("sizes") {
       LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
       auto x = torch::randn({10, 16, 128}, at::requires_grad());
-      auto tup = model->forward({x});
+      auto output = model->forward(x);
       auto y = x.mean();
 
       y.backward();
-      check_lstm_sizes(tup);
+      check_lstm_sizes(output);
 
-      auto next = model->forward({x, tup[1]});
+      auto next = model->forward(x, output.state);
 
       check_lstm_sizes(next);
 
-      torch::Tensor diff = next[1] - tup[1];
+      torch::Tensor diff = next.state - output.state;
 
       // Hiddens changed
       REQUIRE(diff.data().abs().sum().toCFloat() > 1e-3);
@@ -122,13 +119,13 @@
         p[i] = (size - i) / size;
       }
 
-      auto out = model->forward({x});
-      REQUIRE(out[0].ndimension() == 3);
-      REQUIRE(out[0].size(0) == 3);
-      REQUIRE(out[0].size(1) == 4);
-      REQUIRE(out[0].size(2) == 2);
+      auto out = model->forward(x);
+      REQUIRE(out.output.ndimension() == 3);
+      REQUIRE(out.output.size(0) == 3);
+      REQUIRE(out.output.size(1) == 4);
+      REQUIRE(out.output.size(2) == 2);
 
-      auto flat = out[0].data().view(3 * 4 * 2);
+      auto flat = out.output.data().view(3 * 4 * 2);
       float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
                        0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
                        0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -137,12 +134,12 @@
         REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
       }
 
-      REQUIRE(out[1].ndimension() == 4); // (hx, cx) x layers x B x 2
-      REQUIRE(out[1].size(0) == 2);
-      REQUIRE(out[1].size(1) == 1);
-      REQUIRE(out[1].size(2) == 4);
-      REQUIRE(out[1].size(3) == 2);
-      flat = out[1].data().view(16);
+      REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+      REQUIRE(out.state.size(0) == 2);
+      REQUIRE(out.state.size(1) == 1);
+      REQUIRE(out.state.size(2) == 4);
+      REQUIRE(out.state.size(3) == 2);
+      flat = out.state.data().view(16);
       float h_out[] = {0.7889,
                        0.9003,
                        0.7769,
@@ -192,17 +189,17 @@
     LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
     model->cuda();
     auto x = torch::randn({10, 16, 128}, at::requires_grad().device(at::kCUDA));
-    auto tup = model->forward({x});
+    auto output = model->forward(x);
     auto y = x.mean();
 
     y.backward();
-    check_lstm_sizes(tup);
+    check_lstm_sizes(output);
 
-    auto next = model->forward({x, tup[1]});
+    auto next = model->forward(x, output.state);
 
     check_lstm_sizes(next);
 
-    torch::Tensor diff = next[1] - tup[1];
+    torch::Tensor diff = next.state - output.state;
 
     // Hiddens changed
     REQUIRE(diff.data().abs().sum().toCFloat() > 1e-3);
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index a55eede..9c79a93 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -1,5 +1,6 @@
 #include <catch.hpp>
 
+#include <torch/nn/modules.h>
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/sequential.h>
 #include <torch/tensor.h>
@@ -177,12 +178,19 @@
     Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
 
     auto x = torch::randn({1000, 10}, at::requires_grad());
-    auto y =
-        sequential
-            .forward<std::vector<torch::Tensor>>(std::vector<torch::Tensor>{x})
-            .front();
+    auto y = sequential.forward(x);
     REQUIRE(y.ndimension() == 2);
     REQUIRE(y.size(0) == 1000);
     REQUIRE(y.size(1) == 100);
   }
+
+  SECTION("can hold other important modules") {
+    Sequential sequential(
+        Linear(10, 3),
+        Conv2d(1, 2, 3),
+        Dropout(0.5),
+        BatchNorm(5),
+        Embedding(4, 10),
+        LSTM(4, 5));
+  }
 }
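
With SigmoidLinear gone, the same composition is expressed directly with Linear and Functional(at::sigmoid), as in the updated optimizer and serialization tests. A sketch of the pattern, assuming Sequential::forward defaults to returning a single Tensor after this change:

```cpp
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/tensor.h>

void example() {
  torch::nn::Sequential model(
      torch::nn::Linear(2, 3),
      torch::nn::Functional(at::sigmoid),
      torch::nn::Linear(3, 1),
      torch::nn::Functional(at::sigmoid));

  // Gradients flow through the Linear parameters, so backward() works here.
  auto y = model.forward(torch::ones({5, 2}));
  y.sum().backward();
}
```
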
diff --git a/test/cpp/api/serialization.cpp b/test/cpp/api/serialization.cpp
index f997a03..9aaa327 100644
--- a/test/cpp/api/serialization.cpp
+++ b/test/cpp/api/serialization.cpp
@@ -1,5 +1,6 @@
 #include <catch.hpp>
 
+#include <torch/nn/modules/functional.h>
 #include <torch/nn/modules/linear.h>
 #include <torch/nn/modules/sequential.h>
 #include <torch/optim/optimizer.h>
@@ -21,7 +22,10 @@
 namespace {
 std::shared_ptr<Sequential> xor_model() {
   return std::make_shared<Sequential>(
-      torch::SigmoidLinear(2, 8), torch::SigmoidLinear(8, 1));
+      Linear(2, 8),
+      Functional(at::sigmoid),
+      Linear(8, 1),
+      Functional(at::sigmoid));
 }
 } // namespace
 
@@ -244,7 +248,7 @@
 
     auto step = [&](torch::optim::Optimizer& optimizer, Linear model) {
       optimizer.zero_grad();
-      auto y = model->forward({x})[0].sum();
+      auto y = model->forward(x).sum();
       y.backward();
       optimizer.step();
     };
diff --git a/test/cpp/api/util.h b/test/cpp/api/util.h
index f204797..32b6221 100644
--- a/test/cpp/api/util.h
+++ b/test/cpp/api/util.h
@@ -1,8 +1,7 @@
-#include <torch/nn/module.h>
-#include <torch/nn/modules/linear.h>
+#pragma once
 
-#include <memory>
-#include <stdexcept>
+#include <torch/nn/cloneable.h>
+
 #include <string>
 #include <utility>
 
@@ -12,12 +11,6 @@
 // for experimental implementations
 class SimpleContainer : public nn::Cloneable<SimpleContainer> {
  public:
-  virtual std::vector<Tensor> forward(std::vector<Tensor>) {
-    throw std::runtime_error(
-        "SimpleContainer has no forward, maybe you"
-        " wanted to subclass and override this function?");
-  }
-
   void reset() override {}
 
   template <typename ModuleHolder>
@@ -27,19 +20,4 @@
     return Module::register_module(std::move(name), module_holder);
   }
 };
-
-struct SigmoidLinear : nn::Module {
-  SigmoidLinear(int64_t in, int64_t out) : linear(nn::Linear(in, out)) {
-    register_module("linear", linear);
-  }
-
-  explicit SigmoidLinear(nn::Linear linear_) : linear(std::move(linear_)) {
-    register_module("linear", linear);
-  }
-  Tensor forward(Tensor input) {
-    return linear->forward({input}).front().sigmoid();
-  }
-  nn::Linear linear;
-};
-
 } // namespace torch
diff --git a/test/test_torch.py b/test/test_torch.py
index a1c84c8..edc3f85 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5526,7 +5526,7 @@
         self.assertEqual(flat, src)
 
         # out of bounds index
-        with self.assertRaisesRegex(RuntimeError, 'dimension out of range'):
+        with self.assertRaisesRegex(RuntimeError, 'Dimension out of range'):
             src.flatten(5, 10)
 
         # invalid start and end
diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h
index f03005a..24f6977 100644
--- a/torch/csrc/api/include/torch/expanding_array.h
+++ b/torch/csrc/api/include/torch/expanding_array.h
@@ -70,7 +70,7 @@
   }
 
   /// Returns an `ArrayRef` to the underlying `std::array`.
-  operator at::ArrayRef<T>() {
+  operator at::ArrayRef<T>() const {
     return values_;
   }
 
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index c8949b1..0b662bb 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -22,7 +22,10 @@
   explicit BatchNormImpl(BatchNormOptions options);
 
   void reset() override;
-  std::vector<Tensor> forward(std::vector<Tensor>);
+
+  Tensor forward(Tensor input);
+  Tensor pure_forward(Tensor input, Tensor mean, Tensor variance);
+
   const BatchNormOptions& options() const noexcept;
 
  private:
diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h
index fab17c2..f7a2dd1 100644
--- a/torch/csrc/api/include/torch/nn/modules/conv.h
+++ b/torch/csrc/api/include/torch/nn/modules/conv.h
@@ -43,20 +43,35 @@
   ConvOptions<D> options_;
 };
 
-#define CONV_D(D)                                               \
-  class Conv##D##dImpl : public ConvImpl<D, Conv##D##dImpl> {   \
-   public:                                                      \
-    using ConvImpl<D, Conv##D##dImpl>::ConvImpl;                \
-    std::vector<Tensor> forward(std::vector<Tensor> input); \
-  };                                                            \
-  using Conv##D##dOptions = ConvOptions<D>;                     \
-  TORCH_MODULE(Conv##D##d)
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-CONV_D(1);
-CONV_D(2);
-CONV_D(3);
+class Conv1dImpl : public ConvImpl<1, Conv1dImpl> {
+ public:
+  using ConvImpl<1, Conv1dImpl>::ConvImpl;
+  Tensor forward(Tensor input);
+};
+using Conv1dOptions = ConvOptions<1>;
+TORCH_MODULE(Conv1d);
 
-#undef CONV_D
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+class Conv2dImpl : public ConvImpl<2, Conv2dImpl> {
+ public:
+  using ConvImpl<2, Conv2dImpl>::ConvImpl;
+  Tensor forward(Tensor input);
+};
+using Conv2dOptions = ConvOptions<2>;
+TORCH_MODULE(Conv2d);
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+class Conv3dImpl : public ConvImpl<3, Conv3dImpl> {
+ public:
+  using ConvImpl<3, Conv3dImpl>::ConvImpl;
+  Tensor forward(Tensor input);
+};
+using Conv3dOptions = ConvOptions<3>;
+TORCH_MODULE(Conv3d);
 
 } // namespace nn
 } // namespace torch
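
Usage mirroring the updated Conv tests; a sketch only, relying on nothing beyond the declarations spelled out above:

```cpp
#include <torch/nn/modules/conv.h>
#include <torch/tensor.h>

void example() {
  torch::nn::Conv2d model(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5, 5}, at::requires_grad());
  auto y = model->forward(x);  // Tensor in, Tensor out
  y.sum().backward();
}
```
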
diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h
index 901ee3b..5a5762e 100644
--- a/torch/csrc/api/include/torch/nn/modules/dropout.h
+++ b/torch/csrc/api/include/torch/nn/modules/dropout.h
@@ -21,7 +21,7 @@
   explicit DropoutImplBase(DropoutOptions options);
 
   void reset() override;
-  std::vector<Tensor> forward(std::vector<Tensor> input);
+  Tensor forward(Tensor input);
   const DropoutOptions& options() const noexcept;
 
  protected:
diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h
index fcedb72..c345ca4 100644
--- a/torch/csrc/api/include/torch/nn/modules/embedding.h
+++ b/torch/csrc/api/include/torch/nn/modules/embedding.h
@@ -21,7 +21,7 @@
   explicit EmbeddingImpl(EmbeddingOptions options);
 
   void reset() override;
-  std::vector<Tensor> forward(std::vector<Tensor>);
+  Tensor forward(Tensor);
   const EmbeddingOptions& options() const noexcept;
 
  private:
diff --git a/torch/csrc/api/include/torch/nn/modules/functional.h b/torch/csrc/api/include/torch/nn/modules/functional.h
index 21a42b9..e3a9d68 100644
--- a/torch/csrc/api/include/torch/nn/modules/functional.h
+++ b/torch/csrc/api/include/torch/nn/modules/functional.h
@@ -1,11 +1,11 @@
 #pragma once
 
+#include <torch/csrc/utils/variadic.h>
 #include <torch/nn/cloneable.h>
 #include <torch/nn/pimpl.h>
 #include <torch/tensor.h>
 
 #include <functional>
-#include <vector>
 
 namespace torch {
 namespace nn {
@@ -14,13 +14,41 @@
 // Sequential.
 class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
  public:
-  using Function = std::function<std::vector<Tensor>(std::vector<Tensor>)>;
+  using Function = std::function<Tensor(Tensor)>;
+
+  /// A small type that is used only in the constructor of `FunctionalImpl`,
+  /// that allows constructing it with a function with more than one argument,
+  /// and binding all but the first parameter to specific values. It is
+  /// necessary due to interaction with the `ModuleHolder` class, which expects
+  /// to construct a module with `Module({...})` when there is more than one
+  /// argument. It also deals with argument binding.
+  struct BoundFunction {
+    template <
+        typename AnyFunction,
+        typename... Args,
+        typename = torch::enable_if_t<(sizeof...(Args) > 0)>>
+    /* implicit */ BoundFunction(AnyFunction original_function, Args&&... args)
+        : function_(std::bind(
+              original_function,
+              /*input=*/std::placeholders::_1,
+              std::forward<Args>(args)...)) {
+      // std::bind is normally evil, but (1) gcc is broken w.r.t. handling
+      // parameter pack expansion in lambdas and (2) moving parameter packs into
+      // a lambda only works with C++14, so std::bind is the more move-aware
+      // solution here.
+    }
+
+    Function function_;
+  };
 
   explicit FunctionalImpl(Function function);
-  explicit FunctionalImpl(std::function<Tensor(Tensor)> function);
+  explicit FunctionalImpl(BoundFunction bound_function);
 
   void reset() override;
-  std::vector<Tensor> forward(std::vector<Tensor> input);
+  Tensor forward(Tensor input);
+
+  /// Calls forward(input).
+  Tensor operator()(Tensor input);
 
  private:
   Function function_;
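
BoundFunction exists so that `Functional(f, extra_args...)` works when `f` takes more than one argument: everything after the input is bound up front with `std::bind`, with `std::placeholders::_1` standing in for the input tensor. A plain-C++ sketch of that binding pattern (the `elu_like` function is made up for illustration):

```cpp
#include <cmath>
#include <functional>
#include <iostream>

// Stand-in for a multi-argument op such as at::elu(input, alpha, scale).
double elu_like(double x, double alpha, double scale) {
  return scale * (x > 0 ? x : alpha * (std::exp(x) - 1));
}

int main() {
  // Bind everything except the first (input) argument, as BoundFunction does.
  std::function<double(double)> fn =
      std::bind(elu_like, std::placeholders::_1, /*alpha=*/1.0, /*scale=*/0.0);
  std::cout << fn(1.0) << "\n";  // prints 0, because scale is bound to 0
  return 0;
}
```

In the module tests this surfaces as `Functional(at::elu, /*alpha=*/1, /*scale=*/0)`.
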
diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h
index 607fadd..612a2e4 100644
--- a/torch/csrc/api/include/torch/nn/modules/linear.h
+++ b/torch/csrc/api/include/torch/nn/modules/linear.h
@@ -17,12 +17,12 @@
   TORCH_ARG(bool, with_bias) = true;
 };
 
-class LinearImpl : public torch::nn::Cloneable<LinearImpl> {
+class LinearImpl : public Cloneable<LinearImpl> {
  public:
   explicit LinearImpl(LinearOptions options);
 
   void reset() override;
-  std::vector<Tensor> forward(std::vector<Tensor>);
+  Tensor forward(Tensor);
   const LinearOptions& options() const noexcept;
 
  private:
diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h
index 2fa7aa2..7fc8a2a8 100644
--- a/torch/csrc/api/include/torch/nn/modules/rnn.h
+++ b/torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -16,6 +16,12 @@
 
 namespace torch {
 namespace nn {
+
+struct RNNOutput {
+  Tensor output;
+  Tensor state;
+};
+
 namespace detail {
 struct RNNOptionsBase {
   RNNOptionsBase(int64_t input_size, int64_t hidden_size);
@@ -40,7 +46,7 @@
       int64_t number_of_gates = 1,
       bool has_cell_state = false);
 
-  std::vector<Tensor> forward(std::vector<Tensor>);
+  RNNOutput forward(Tensor input, Tensor state = {});
 
   void reset() override;
 
@@ -55,12 +61,10 @@
   void flatten_parameters_for_cudnn();
 
  protected:
-  virtual std::vector<Tensor> cell_forward(
-      std::vector<Tensor>,
-      int64_t layer) = 0;
+  virtual Tensor cell_forward(Tensor input, Tensor state, int64_t layer) = 0;
 
-  std::vector<Tensor> CUDNN_forward(std::vector<Tensor>);
-  std::vector<Tensor> autograd_forward(std::vector<Tensor>);
+  RNNOutput CUDNN_forward(Tensor input, Tensor state);
+  RNNOutput autograd_forward(Tensor input, Tensor state);
 
   std::vector<Tensor> flat_weights() const;
 
@@ -114,7 +118,7 @@
   const RNNOptions& options() const noexcept;
 
  private:
-  std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
 
   RNNOptions options_;
   std::function<Tensor(Tensor)> activation_function_;
@@ -133,7 +137,7 @@
   const LSTMOptions& options() const noexcept;
 
  private:
-  std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
 };
 
 TORCH_MODULE(LSTM);
@@ -149,7 +153,7 @@
   const GRUOptions& options() const noexcept;
 
  private:
-  std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+  Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
 };
 
 TORCH_MODULE(GRU);
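
RNN forward methods now return a small named struct instead of a positional vector, and the hidden state becomes an optional second argument. Usage mirroring the updated rnn tests (a sketch, assuming only the declarations above):

```cpp
#include <torch/nn/modules/rnn.h>
#include <torch/tensor.h>

void example() {
  torch::nn::LSTM model(torch::nn::LSTMOptions(128, 64).layers(3));
  auto x = torch::randn({10, 16, 128});

  auto out = model->forward(x);              // RNNOutput{output, state}
  auto next = model->forward(x, out.state);  // pass the state back in
  auto diff = next.state - out.state;
}
```
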
diff --git a/torch/csrc/api/include/torch/nn/pimpl.h b/torch/csrc/api/include/torch/nn/pimpl.h
index 5f772c5..30bfcd4 100644
--- a/torch/csrc/api/include/torch/nn/pimpl.h
+++ b/torch/csrc/api/include/torch/nn/pimpl.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <torch/csrc/utils/variadic.h>
+#include <torch/tensor.h>
 
 #include <memory>
 #include <type_traits>
@@ -71,6 +72,12 @@
     return impl_.get();
   }
 
+  /// Forwards to the call operator of the contained module.
+  template <typename... Args>
+  Tensor operator()(Args&&... args) {
+    return (*impl_)(std::forward<Args>(args)...);
+  }
+
   /// Returns the underlying module.
   const std::shared_ptr<Contained>& get() const {
     AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
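
The holder's operator() simply perfect-forwards to the contained module's call operator. A standalone sketch of the forwarding pattern (plain C++14, not the torch implementation):

```cpp
#include <memory>
#include <utility>

template <typename Contained>
struct Holder {
  std::shared_ptr<Contained> impl_;

  // Forward any arguments to the contained object's operator().
  template <typename... Args>
  auto operator()(Args&&... args) {
    return (*impl_)(std::forward<Args>(args)...);
  }
};

struct Doubler {
  int operator()(int x) const { return 2 * x; }
};

int main() {
  Holder<Doubler> h{std::make_shared<Doubler>()};
  return h(21) == 42 ? 0 : 1;
}
```

This is what lets the tests call `functional(torch::ones(5, at::requires_grad()))` on the holder without spelling out `->forward`.
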
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 267a09e..6de6efa 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -33,11 +33,13 @@
   }
 }
 
-std::vector<Tensor> BatchNormImpl::forward(std::vector<Tensor> inputs) {
-  auto& input = inputs[0];
-  auto& running_mean_ = (options_.stateful_ ? this->running_mean_ : inputs[1]);
-  auto& running_variance_ =
-      (options_.stateful_ ? this->running_variance_ : inputs[2]);
+Tensor BatchNormImpl::forward(Tensor input) {
+  return pure_forward(input, Tensor(), Tensor());
+}
+
+Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) {
+  auto& running_mean = options_.stateful_ ? running_mean_ : mean;
+  auto& running_variance = options_.stateful_ ? running_variance_ : variance;
 
   if (is_training()) {
     const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
@@ -46,18 +48,16 @@
         "BatchNorm expected more than 1 value per channel when training!");
   }
 
-  auto output = at::batch_norm(
+  return at::batch_norm(
       input,
       weight_,
       bias_,
-      running_mean_,
-      running_variance_,
+      running_mean,
+      running_variance,
       is_training(),
       options_.momentum_,
       options_.eps_,
       torch::cuda::cudnn_is_available());
-
-  return std::vector<Tensor>({output});
 }
 
 const BatchNormOptions& BatchNormImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp
index 1efd0ad..36ffb83 100644
--- a/torch/csrc/api/src/nn/modules/conv.cpp
+++ b/torch/csrc/api/src/nn/modules/conv.cpp
@@ -70,88 +70,87 @@
   return options_;
 }
 
-std::vector<Tensor> Conv1dImpl::forward(std::vector<Tensor> input) {
-  AT_ASSERT(input.front().ndimension() == 3);
+Tensor Conv1dImpl::forward(Tensor input) {
+  AT_ASSERT(input.ndimension() == 3);
 
   if (options_.transposed_) {
-    return {at::conv_transpose1d(
-        input.front(),
+    return at::conv_transpose1d(
+        input,
         weight_,
         bias_,
         options_.stride_,
         options_.padding_,
         options_.output_padding_,
         options_.groups_,
-        options_.dilation_)};
+        options_.dilation_);
   }
-  return {at::conv1d(
-      input.front(),
+  return at::conv1d(
+      input,
       weight_,
       bias_,
       options_.stride_,
       options_.padding_,
       options_.dilation_,
-      options_.groups_)};
+      options_.groups_);
 }
 
-std::vector<Tensor> Conv2dImpl::forward(std::vector<Tensor> input) {
-  AT_ASSERT(input.front().ndimension() == 4);
+Tensor Conv2dImpl::forward(Tensor input) {
+  AT_ASSERT(input.ndimension() == 4);
 
   if (options_.transposed_) {
-    return {at::conv_transpose2d(
-        input.front(),
+    return at::conv_transpose2d(
+        input,
         weight_,
         bias_,
         options_.stride_,
         options_.padding_,
         options_.output_padding_,
         options_.groups_,
-        options_.dilation_)};
+        options_.dilation_);
   }
-  return {at::conv2d(
-      input.front(),
+  return at::conv2d(
+      input,
       weight_,
       bias_,
       options_.stride_,
       options_.padding_,
       options_.dilation_,
-      options_.groups_)};
+      options_.groups_);
 }
 
-std::vector<Tensor> Conv3dImpl::forward(std::vector<Tensor> input) {
-  AT_ASSERT(input.front().ndimension() == 5);
+Tensor Conv3dImpl::forward(Tensor input) {
+  AT_ASSERT(input.ndimension() == 5);
 
   if (options_.transposed_) {
-    return {at::conv_transpose3d(
-        input.front(),
+    return at::conv_transpose3d(
+        input,
         weight_,
         bias_,
         options_.stride_,
         options_.padding_,
         options_.output_padding_,
         options_.groups_,
-        options_.dilation_)};
+        options_.dilation_);
   } else {
-    return {at::conv3d(
-        input.front(),
+    return at::conv3d(
+        input,
         weight_,
         bias_,
         options_.stride_,
         options_.padding_,
         options_.dilation_,
-        options_.groups_)};
+        options_.groups_);
   }
 }
 
-#define CONV_D(D)                 \
-  template struct ConvOptions<D>; \
-  template class ConvImpl<D, Conv##D##dImpl>
+template struct ConvOptions<1>;
+template class ConvImpl<1, Conv1dImpl>;
 
-CONV_D(1);
-CONV_D(2);
-CONV_D(3);
+template struct ConvOptions<2>;
+template class ConvImpl<2, Conv2dImpl>;
 
-#undef CONV_D
+template struct ConvOptions<3>;
+template class ConvImpl<3, Conv3dImpl>;
 
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp
index ee5006e..2b544b2 100644
--- a/torch/csrc/api/src/nn/modules/dropout.cpp
+++ b/torch/csrc/api/src/nn/modules/dropout.cpp
@@ -21,19 +21,16 @@
 void DropoutImplBase<Derived>::reset() {}
 
 template <typename Derived>
-std::vector<Tensor> DropoutImplBase<Derived>::forward(
-    std::vector<Tensor> input) {
+Tensor DropoutImplBase<Derived>::forward(Tensor input) {
   if (options_.rate_ == 0 || !this->is_training()) {
     return input;
   }
-  std::vector<Tensor> output;
-  for (const auto& value : input) {
-    const auto noise = (noise_mask(value).uniform_(0, 1) > options_.rate_)
-                           .toType(value.type().scalarType())
-                           .mul_(1.0f / (1.0f - options_.rate_));
-    output.push_back(value * noise);
-  }
-  return output;
+
+  auto scale = 1.0f / (1.0f - options_.rate_);
+  auto boolean_mask = noise_mask(input).uniform_(0, 1) > options_.rate_;
+  auto noise = boolean_mask.to(input.dtype()).mul_(scale);
+
+  return input * noise;
 }
 
 template <typename Derived>
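
The rewritten forward is standard inverted dropout: keep each element with probability 1 - rate and scale the survivors by 1 / (1 - rate) so the expected activation is unchanged. A standalone numeric sketch (plain C++, not the torch implementation):

```cpp
#include <cstdlib>
#include <vector>

std::vector<float> inverted_dropout(const std::vector<float>& input, float rate) {
  const float scale = 1.0f / (1.0f - rate);
  std::vector<float> output(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Keep with probability (1 - rate), then rescale: E[output[i]] == input[i].
    const bool keep = (std::rand() / static_cast<float>(RAND_MAX)) > rate;
    output[i] = keep ? input[i] * scale : 0.0f;
  }
  return output;
}
```
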
diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp
index f69c8a5..93cb876 100644
--- a/torch/csrc/api/src/nn/modules/embedding.cpp
+++ b/torch/csrc/api/src/nn/modules/embedding.cpp
@@ -23,8 +23,8 @@
   table_.data().normal_(0, 1);
 }
 
-std::vector<Tensor> EmbeddingImpl::forward(std::vector<Tensor> input) {
-  return {at::embedding(table_, /*indices=*/input[0])};
+Tensor EmbeddingImpl::forward(Tensor input) {
+  return at::embedding(table_, /*indices=*/input);
 }
 
 const EmbeddingOptions& EmbeddingImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/functional.cpp b/torch/csrc/api/src/nn/modules/functional.cpp
index 5245adf..d4a4cc3 100644
--- a/torch/csrc/api/src/nn/modules/functional.cpp
+++ b/torch/csrc/api/src/nn/modules/functional.cpp
@@ -4,22 +4,23 @@
 
 #include <functional>
 #include <utility>
-#include <vector>
 
 namespace torch {
 namespace nn {
-FunctionalImpl::FunctionalImpl(Function function)
+FunctionalImpl::FunctionalImpl(std::function<Tensor(Tensor)> function)
     : function_(std::move(function)) {}
 
-FunctionalImpl::FunctionalImpl(std::function<Tensor(Tensor)> function)
-    : function_([function](std::vector<Tensor> input) {
-        return std::vector<Tensor>({function(input.front())});
-      }) {}
+FunctionalImpl::FunctionalImpl(BoundFunction bound_function)
+    : function_(std::move(bound_function.function_)) {}
 
 void FunctionalImpl::reset() {}
 
-std::vector<Tensor> FunctionalImpl::forward(std::vector<Tensor> input) {
+Tensor FunctionalImpl::forward(Tensor input) {
   return function_(input);
 }
+
+Tensor FunctionalImpl::operator()(Tensor input) {
+  return forward(input);
+}
 } // namespace nn
 } // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp
index d1178c8..065c842 100644
--- a/torch/csrc/api/src/nn/modules/linear.cpp
+++ b/torch/csrc/api/src/nn/modules/linear.cpp
@@ -26,19 +26,18 @@
   }
 }
 
-std::vector<Tensor> LinearImpl::forward(std::vector<Tensor> input) {
-  auto x = input[0];
-  if (x.ndimension() == 2 && options_.with_bias_) {
+Tensor LinearImpl::forward(Tensor input) {
+  if (input.ndimension() == 2 && options_.with_bias_) {
     // Fused op is marginally faster
-    AT_ASSERT(x.size(1) == weight_.size(1));
-    return {at::addmm(bias_, x, weight_.t())};
+    AT_ASSERT(input.size(1) == weight_.size(1));
+    return {at::addmm(bias_, input, weight_.t())};
   }
 
-  auto output = x.matmul(weight_.t());
+  auto output = input.matmul(weight_.t());
   if (options_.with_bias_) {
     output += bias_;
   }
-  return {output};
+  return output;
 }
 
 const LinearOptions& LinearImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp
index 1a5f367..88a7595 100644
--- a/torch/csrc/api/src/nn/modules/rnn.cpp
+++ b/torch/csrc/api/src/nn/modules/rnn.cpp
@@ -96,14 +96,12 @@
 }
 
 template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::forward(std::vector<Tensor> inputs) {
-  std::vector<Tensor> inp = {inputs[0],
-                             inputs.size() > 1 ? inputs[1] : Tensor()};
-  if (cudnn_mode_.has_value() && at::cudnn_is_acceptable(inp[0]) &&
+RNNOutput RNNImplBase<Derived>::forward(Tensor input, Tensor state) {
+  if (cudnn_mode_.has_value() && at::cudnn_is_acceptable(input) &&
       options_.dropout_ == 0) {
-    return {CUDNN_forward(inp)};
+    return CUDNN_forward(input, state);
   } else {
-    return {autograd_forward(inp)};
+    return autograd_forward(input, state);
   }
 }
 
@@ -122,42 +120,39 @@
 }
 
 template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::autograd_forward(
-    std::vector<Tensor> inputs) {
-  auto inp = inputs[0];
-
-  std::vector<Tensor> state;
-  auto has_hidden = inputs[1].defined();
-  auto layer_dimension = has_hidden ? inputs[1].ndimension() - 3 : -1;
+RNNOutput RNNImplBase<Derived>::autograd_forward(Tensor input, Tensor state) {
+  std::vector<at::Tensor> new_state;
+  auto has_hidden = state.defined();
+  auto layer_dimension = has_hidden ? state.ndimension() - 3 : -1;
   for (int64_t layer = 0; layer < options_.layers_; layer++) {
-    state.push_back(
-        has_hidden ? inputs[1].select(layer_dimension, layer) : Tensor());
+    new_state.push_back(
+        has_hidden ? state.select(layer_dimension, layer) : Tensor());
   }
 
   auto output = torch::zeros(
-      {inp.size(0), inp.size(1), options_.hidden_size_}, inp.options());
-  for (int64_t t = 0; t < inp.size(0); t++) {
-    auto x = inp.select(0, t);
+      {input.size(0), input.size(1), options_.hidden_size_}, input.options());
+  for (int64_t t = 0; t < input.size(0); t++) {
+    auto x = input.select(0, t);
     for (int64_t i = 0; i < options_.layers_; i++) {
       // cell_forward() returns a stacked tensor of one or more cell states.
-      auto layer_output = cell_forward({x, state[i]}, i);
+      auto layer_output = cell_forward(x, new_state[i], i);
       // If there are multiple cell states, keep all. If there is only one,
       // the first dimension will be 1, so `.squeeze(0)` will unpack it.
-      state[i] = layer_output[0].squeeze(0);
+      new_state[i] = layer_output.squeeze(0);
       // x should always be the hidden cell state h, assumed to be the zero-th.
-      x = layer_output[0][0];
+      x = layer_output[0];
       output.select(0, t).copy_(x);
       if (options_.dropout_ > 0 && i != options_.layers_ - 1) {
-        x = dropout_module_->forward({x})[0];
+        x = dropout_module_->forward(x);
       }
     }
   }
 
-  auto state_output = at::stack(TensorListView(state));
+  auto state_output = at::stack(new_state);
   if (has_cell_state_) {
     state_output.transpose_(0, 1);
   }
-  return std::vector<Tensor>({output, state_output});
+  return {output, state_output};
 }
 
 template <typename Derived>
@@ -200,26 +195,26 @@
 }
 
 template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::CUDNN_forward(
-    std::vector<Tensor> inputs) {
-  auto x = inputs[0];
+RNNOutput RNNImplBase<Derived>::CUDNN_forward(Tensor input, Tensor state) {
   Tensor hx, cx;
-  if (inputs[1].defined()) {
+  if (state.defined()) {
     if (has_cell_state_) {
-      hx = inputs[1][0];
-      cx = inputs[1][1];
+      hx = state[0];
+      cx = state[1];
     } else {
-      hx = inputs[1];
+      hx = state;
     }
   } else {
     hx = torch::zeros(
-        {options_.layers_, x.size(1), options_.hidden_size_}, x.options());
+        {options_.layers_, input.size(1), options_.hidden_size_},
+        input.options());
     if (has_cell_state_) {
       cx = torch::zeros(
-          {options_.layers_, x.size(1), options_.hidden_size_}, x.options());
+          {options_.layers_, input.size(1), options_.hidden_size_},
+          input.options());
     }
   }
-  auto dropout_state = torch::empty({}, x.type());
+  auto dropout_state = torch::empty({}, input.type());
 
   std::vector<void*> weight_data_ptrs;
   for (auto& p : this->parameters()) {
@@ -235,7 +230,7 @@
 
   // tup = std::tuple of output, hy, cy, reserve, new_weight_buf
   auto tup = _cudnn_rnn(
-      x,
+      input,
       TensorListView(flat_weights()),
       /*weight_stride=*/options_.with_bias_ ? 4 : 2,
       flat_weights_,
@@ -261,7 +256,7 @@
   }
 
   Tensor output = std::get<0>(tup);
-  return std::vector<Tensor>({output, hidden_output});
+  return {output, hidden_output};
 }
 
 template <typename Derived>
@@ -323,18 +318,15 @@
   }
 }
 
-std::vector<Tensor> RNNImpl::cell_forward(
-    std::vector<Tensor> inputs,
-    int64_t layer) {
-  auto x = inputs[0];
-  auto hx = inputs[1].defined()
-      ? inputs[1]
-      : torch::zeros({x.size(0), options_.hidden_size_}, x.options());
+Tensor RNNImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+  auto hx = state.defined()
+      ? state
+      : torch::zeros({input.size(0), options_.hidden_size_}, input.options());
 
-  auto h = linear(x, ihw_[layer], ihb_[layer]) +
+  auto h = linear(input, ihw_[layer], ihb_[layer]) +
       linear(hx, hhw_[layer], hhb_[layer]);
 
-  return {at::stack(TensorListView(activation_function_(h)))};
+  return at::stack(activation_function_(h));
 }
 
 const RNNOptions& RNNImpl::options() const noexcept {
@@ -350,17 +342,15 @@
           /*number_of_gates=*/4,
           /*has_cell_state=*/true) {}
 
-std::vector<Tensor> LSTMImpl::cell_forward(
-    std::vector<Tensor> inputs,
-    int64_t layer) {
-  auto x = inputs[0];
-  auto hid = inputs[1].defined()
-      ? inputs[1]
-      : torch::zeros({2, x.size(0), options_.hidden_size_}, x.options());
+Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+  auto hid = state.defined()
+      ? state
+      : torch::zeros(
+            {2, input.size(0), options_.hidden_size_}, input.options());
   auto hx = hid[0];
   auto cx = hid[1];
 
-  auto gates = linear(x, ihw_[layer], ihb_[layer]) +
+  auto gates = linear(input, ihw_[layer], ihb_[layer]) +
       linear(hx, hhw_[layer], hhb_[layer]);
 
   auto chunked = gates.chunk(4, 1);
@@ -372,7 +362,7 @@
   auto cy = (forget_gate * cx) + (in_gate * cell_gate);
   auto hy = out_gate * cy.tanh();
 
-  return {at::stack(TensorListView({hy, cy}), 0)};
+  return at::stack(TensorListView{hy, cy}, 0);
 }
 
 const LSTMOptions& LSTMImpl::options() const noexcept {
@@ -387,16 +377,13 @@
           /*cudnn_mode=*/CuDNNMode::GRU,
           /*number_of_gates=*/3) {}
 
-std::vector<Tensor> GRUImpl::cell_forward(
-    std::vector<Tensor> inputs,
-    int64_t layer) {
-  auto x = inputs[0];
-  auto hx = inputs[1].defined()
-      ? inputs[1]
-      : torch::zeros({x.size(0), options_.hidden_size_}, x.options());
+Tensor GRUImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+  auto hx = state.defined()
+      ? state
+      : torch::zeros({input.size(0), options_.hidden_size_}, input.options());
 
-  auto gi = linear(x, ihw_[layer], ihb_[layer]);
-  auto gh = linear(x, hhw_[layer], hhb_[layer]);
+  auto gi = linear(input, ihw_[layer], ihb_[layer]);
+  auto gh = linear(input, hhw_[layer], hhb_[layer]);
   auto gic = gi.chunk(3, 1);
   auto ghc = gh.chunk(3, 1);
 
@@ -405,7 +392,7 @@
   auto new_gate = (gic[2] + reset_gate * ghc[2]).tanh_();
   auto hy = new_gate + input_gate * (hx - new_gate);
 
-  return {at::stack(TensorListView(hy))};
+  return at::stack(TensorListView(hy));
 }
 
 const GRUOptions& GRUImpl::options() const noexcept {