[C++ API] Better forward methods (#8739)
* Better forward methods in C++ API
* Capitalize error message in test_torch.test_flatten
* Support for operator()
* Add operator() to Functional
* Get rid of SigmoidLinear
* Add BoundFunction to FunctionalImpl
* Remove the CONV_D macro from conv because it makes compiler errors nastier to read
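For reference, a minimal usage sketch of the new interface (the header paths mirror the includes used in the tests below; the wrapper function itself is hypothetical):

```cpp
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/tensor.h>

// Expects x of shape {N, 5}.
torch::Tensor example(torch::Tensor x) {
  using namespace torch::nn;

  Linear linear(5, 2);
  // Before this PR: auto y = linear->forward({x})[0];
  auto y = linear->forward(x);  // forward() now takes and returns a single Tensor

  Functional relu(at::relu);    // Functional wraps any Tensor(Tensor) callable
  return relu(y);               // operator() on the holder forwards to forward()
}
```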
diff --git a/aten/src/ATen/ExpandUtils.cpp b/aten/src/ATen/ExpandUtils.cpp
index 9b6d228..172201a 100644
--- a/aten/src/ATen/ExpandUtils.cpp
+++ b/aten/src/ATen/ExpandUtils.cpp
@@ -14,26 +14,28 @@
long dimB = dimsB - 1 - offset;
long sizeA = (dimA >= 0) ? a[dimA] : 1;
long sizeB = (dimB >= 0) ? b[dimB] : 1;
- if (sizeA == sizeB || sizeA == 1 || sizeB == 1) {
- expandedSizes[i] = std::max(sizeA, sizeB);
- } else {
- std::ostringstream oss;
- oss << "The size of tensor a (" << sizeA << ") must match the size of tensor b ("
- << sizeB << ") at non-singleton dimension " << i;
- throw std::runtime_error(oss.str());
- }
+
+ AT_CHECK(
+ sizeA == sizeB || sizeA == 1 || sizeB == 1,
+ "The size of tensor a (", sizeA,
+ ") must match the size of tensor b (", sizeB,
+ ") at non-singleton dimension ", i);
+
+ expandedSizes[i] = std::max(sizeA, sizeB);
}
return expandedSizes;
}
-std::tuple<std::vector<int64_t>, std::vector<int64_t> >
-inferExpandGeometry(const Tensor &tensor, IntList sizes) {
+std::tuple<std::vector<int64_t>, std::vector<int64_t>> inferExpandGeometry(
+ const Tensor& tensor,
+ IntList sizes) {
int64_t ndim = sizes.size();
if (tensor.dim() == 0) {
std::vector<int64_t> expandedStrides(ndim, 0);
- return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(sizes.vec(), expandedStrides);
+ return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
+ sizes.vec(), expandedStrides);
}
std::vector<int64_t> expandedSizes(ndim);
std::vector<int64_t> expandedStrides(ndim);
@@ -43,34 +45,35 @@
int64_t offset = ndim - 1 - i;
int64_t dim = tensor.dim() - 1 - offset;
int64_t size = (dim >= 0) ? tensor.sizes()[dim] : 1;
- int64_t stride = (dim >= 0) ?
- tensor.strides()[dim] : expandedSizes[i + 1] * expandedStrides[i + 1];
+ int64_t stride = (dim >= 0) ? tensor.strides()[dim]
+ : expandedSizes[i + 1] * expandedStrides[i + 1];
int64_t targetSize = sizes[i];
if (targetSize == -1) {
- if (dim < 0) {
- std::ostringstream oss;
- oss << "The expanded size of the tensor (" << targetSize << ") isn't allowed in a leading, "
- << "non-existing dimension " << i;
- throw std::runtime_error(oss.str());
- } else {
- targetSize = size;
- }
+ AT_CHECK(
+ dim >= 0,
+ "The expanded size of the tensor (",
+ targetSize,
+ ") isn't allowed in a leading, non-existing dimension ",
+ i);
+ targetSize = size;
}
if (size != targetSize) {
- if (size == 1) {
- size = targetSize;
- stride = 0;
- } else {
- std::ostringstream oss;
- oss << "The expanded size of the tensor (" << targetSize << ") must match the existing size (" << size
- << ") at non-singleton dimension " << i;
- throw std::runtime_error(oss.str());
- }
+ AT_CHECK(
+ size == 1,
+ "The expanded size of the tensor (",
+ targetSize,
+ ") must match the existing size (",
+ size,
+ ") at non-singleton dimension ",
+ i);
+ size = targetSize;
+ stride = 0;
}
expandedSizes[i] = size;
expandedStrides[i] = stride;
}
- return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(expandedSizes, expandedStrides);
+ return std::tuple<std::vector<int64_t>, std::vector<int64_t>>(
+ expandedSizes, expandedStrides);
}
-}
+} // namespace at
diff --git a/aten/src/ATen/WrapDimUtils.h b/aten/src/ATen/WrapDimUtils.h
index 953f808..688b2ab 100644
--- a/aten/src/ATen/WrapDimUtils.h
+++ b/aten/src/ATen/WrapDimUtils.h
@@ -17,12 +17,10 @@
int64_t min = -dim_post_expr;
int64_t max = dim_post_expr - 1;
- if (dim < min || dim > max) {
- std::ostringstream oss;
- oss << "dimension out of range (expected to be in range of [" << min
- << ", " << max << "], but got " << dim << ")",
- throw std::runtime_error(oss.str());
- }
+ AT_CHECK(
+ dim >= min && dim <= max,
+ "Dimension out of range (expected to be in range of [",
+ min, ", ", max, "], but got ", dim, ")");
if (dim < 0) dim += dim_post_expr;
return dim;
}
diff --git a/test/cpp/api/integration.cpp b/test/cpp/api/integration.cpp
index 9786fcb..87ed246 100644
--- a/test/cpp/api/integration.cpp
+++ b/test/cpp/api/integration.cpp
@@ -230,7 +230,7 @@
std::cout << "Num correct: " << correct.data().sum().toCFloat() << " out of "
<< telabel.size(0) << std::endl;
return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
-};
+}
TEST_CASE("integration/cartpole") {
std::cerr << "Training episodic policy gradient with a critic for up to 3000"
@@ -245,16 +245,16 @@
std::vector<torch::Tensor> saved_values;
std::vector<float> rewards;
- auto forward = [&](std::vector<torch::Tensor> inp) {
- auto x = linear->forward(inp)[0].clamp_min(0);
- torch::Tensor actions = policyHead->forward({x})[0];
- torch::Tensor value = valueHead->forward({x})[0];
+ auto forward = [&](torch::Tensor inp) {
+ auto x = linear->forward(inp).clamp_min(0);
+ torch::Tensor actions = policyHead->forward(x);
+ torch::Tensor value = valueHead->forward(x);
return std::make_tuple(at::softmax(actions, -1), value);
};
- auto selectAction = [&](torch::Tensor state) {
- // Only work on single state now, change index to gather for batch
- auto out = forward({state});
+ auto selectAction = [&](at::Tensor state) {
+ // Only work on single state right now, change index to gather for batch
+ auto out = forward(state);
auto probs = torch::Tensor(std::get<0>(out));
auto value = torch::Tensor(std::get<1>(out));
auto action = probs.data().multinomial(1)[0].toCInt();
@@ -340,16 +340,15 @@
auto linear2 = model->add(Linear(50, 10), "linear2");
auto forward = [&](torch::Tensor x) {
- x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
- .clamp_min(0);
- x = conv2->forward({x})[0];
- x = drop2d->forward({x})[0];
+ x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
+ x = conv2->forward(x);
+ x = drop2d->forward(x);
x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
x = x.view({-1, 320});
- x = linear1->forward({x})[0].clamp_min(0);
- x = drop->forward({x})[0];
- x = linear2->forward({x})[0];
+ x = linear1->forward(x).clamp_min(0);
+ x = drop->forward(x);
+ x = linear2->forward(x);
x = at::log_softmax(x, 1);
return x;
};
@@ -378,16 +377,15 @@
auto linear2 = model->add(Linear(50, 10), "linear2");
auto forward = [&](torch::Tensor x) {
- x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2}))
- .clamp_min(0);
- x = batchnorm2d->forward({x})[0];
- x = conv2->forward({x})[0];
+ x = std::get<0>(at::max_pool2d(conv1->forward(x), {2, 2})).clamp_min(0);
+ x = batchnorm2d->forward(x);
+ x = conv2->forward(x);
x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
x = x.view({-1, 320});
- x = linear1->forward({x})[0].clamp_min(0);
- x = batchnorm1->forward({x})[0];
- x = linear2->forward({x})[0];
+ x = linear1->forward(x).clamp_min(0);
+ x = batchnorm1->forward(x);
+ x = linear2->forward(x);
x = at::log_softmax(x, 1);
return x;
};
diff --git a/test/cpp/api/misc.cpp b/test/cpp/api/misc.cpp
index de4cd03..c2d505d 100644
--- a/test/cpp/api/misc.cpp
+++ b/test/cpp/api/misc.cpp
@@ -22,7 +22,7 @@
torch::NoGradGuard guard;
Linear model(5, 2);
auto x = torch::randn({10, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
diff --git a/test/cpp/api/module.cpp b/test/cpp/api/module.cpp
index 6e0f887..2db70584 100644
--- a/test/cpp/api/module.cpp
+++ b/test/cpp/api/module.cpp
@@ -38,7 +38,7 @@
TEST_CASE("module/zero-grad") {
Linear module(3, 4);
auto weight = torch::ones({8, 3}, at::requires_grad());
- auto loss = module->forward({weight}).front().sum();
+ auto loss = module->forward(weight).sum();
loss.backward();
for (auto& parameter : module->parameters()) {
auto grad = parameter->grad();
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 80a8cbf..64d6dc3 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -21,10 +21,6 @@
l3 = register_module("l3", Linear(5, 100));
}
- std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
- return input;
- }
-
Linear l1, l2, l3;
};
@@ -36,10 +32,6 @@
param_ = register_parameter("param", torch::empty({3, 2, 21}));
}
- std::vector<torch::Tensor> forward(std::vector<torch::Tensor> input) {
- return input;
- };
-
torch::Tensor param_;
Linear l1;
std::shared_ptr<TestModel> t;
@@ -50,7 +42,7 @@
SECTION("1d") {
Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -66,7 +58,7 @@
SECTION("even") {
Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -82,7 +74,7 @@
SECTION("uneven") {
Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
auto x = torch::randn({2, 3, 5, 4}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -98,7 +90,7 @@
SECTION("3d") {
Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
auto x = torch::randn({2, 3, 5, 5, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -116,7 +108,7 @@
SECTION("basic1") {
Linear model(5, 2);
auto x = torch::randn({10, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -136,9 +128,9 @@
auto l3 = model->add(Linear(5, 100), "l3");
auto x = torch::randn({1000, 10}, at::requires_grad());
- x = l1->forward({x})[0].clamp_min(0);
- x = l2->forward({x})[0].clamp_min(0);
- x = l3->forward({x})[0].clamp_min(0);
+ x = l1->forward(x).clamp_min(0);
+ x = l2->forward(x).clamp_min(0);
+ x = l3->forward(x).clamp_min(0);
x.backward();
REQUIRE(x.ndimension() == 2);
@@ -154,7 +146,7 @@
// Cannot get gradients to change indices (input) - only for embedding
// params
auto x = torch::full({10}, dict_size - 1, torch::kInt64);
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -169,7 +161,7 @@
SECTION("list") {
Embedding model(6, 4);
auto x = torch::full({2, 3}, 5, torch::kInt64);
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -183,7 +175,7 @@
SECTION("dropout") {
Dropout dropout(0.5);
torch::Tensor x = torch::ones(100, at::requires_grad());
- torch::Tensor y = dropout->forward({x})[0];
+ torch::Tensor y = dropout->forward(x);
y.backward();
REQUIRE(y.ndimension() == 1);
@@ -194,7 +186,7 @@
// REQUIRE(y.sum().toCFloat() > 70); // Probably
dropout->eval();
- y = dropout->forward({x})[0];
+ y = dropout->forward(x);
REQUIRE(y.data().sum().toCFloat() == 100);
}
@@ -219,17 +211,31 @@
}
SECTION("functional") {
- bool was_called = false;
- // clang-format off
- auto functional = Functional([&was_called](std::vector<torch::Tensor> input) {
- was_called = true;
- return input;
- });
- // clang-format on
- auto output = functional->forward({torch::ones(5, at::requires_grad())});
- REQUIRE(was_called);
- REQUIRE(output.size() == 1);
- REQUIRE(output.front().equal(torch::ones(5, at::requires_grad())));
+ {
+ bool was_called = false;
+ auto functional = Functional([&was_called](torch::Tensor input) {
+ was_called = true;
+ return input;
+ });
+ auto output = functional->forward(torch::ones(5, at::requires_grad()));
+ REQUIRE(was_called);
+ REQUIRE(output.equal(torch::ones(5, at::requires_grad())));
+
+ was_called = false;
+ output = functional(torch::ones(5, at::requires_grad()));
+ REQUIRE(was_called);
+ REQUIRE(output.equal(torch::ones(5, at::requires_grad())));
+ }
+ {
+ auto functional = Functional(at::relu);
+ REQUIRE(functional(torch::ones({})).data().toCFloat() == 1);
+ REQUIRE(functional(torch::ones({})).toCFloat() == 1);
+ REQUIRE(functional(torch::ones({}) * -1).toCFloat() == 0);
+ }
+ {
+ auto functional = Functional(at::elu, /*alpha=*/1, /*scale=*/0);
+ REQUIRE(functional(torch::ones({})).toCFloat() == 0);
+ }
}
}
@@ -238,7 +244,7 @@
Linear model(5, 2);
model->cuda();
auto x = torch::randn({10, 5}, at::device(at::kCUDA).requires_grad(true));
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
@@ -255,7 +261,7 @@
model->cuda();
model->cpu();
auto x = torch::randn({10, 5}, at::requires_grad());
- auto y = model->forward({x})[0];
+ auto y = model->forward(x);
torch::Tensor s = y.sum();
s.backward();
diff --git a/test/cpp/api/optim.cpp b/test/cpp/api/optim.cpp
index 75ee0f3..e237493 100644
--- a/test/cpp/api/optim.cpp
+++ b/test/cpp/api/optim.cpp
@@ -1,6 +1,7 @@
#include <catch.hpp>
#include <torch/nn/module.h>
+#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim.h>
@@ -62,18 +63,22 @@
torch::manual_seed(0);
Sequential model(
- torch::SigmoidLinear(Linear(2, 3)), torch::SigmoidLinear(Linear(3, 1)));
+ Linear(2, 3),
+ Functional(at::sigmoid),
+ Linear(3, 1),
+ Functional(at::sigmoid));
+
model.to(torch::kFloat64);
// Use exact input values because matching random values is hard.
auto parameters = model.parameters();
- parameters.at("0.linear.weight").data().flatten() = at::tensor(
+ parameters.at("0.weight").data().flatten() = at::tensor(
{-0.2109, -0.4976, -0.1413, -0.3420, -0.2524, 0.6976}, torch::kFloat64);
- parameters.at("0.linear.bias").data() =
+ parameters.at("0.bias").data() =
at::tensor({-0.1085, -0.2979, 0.6892}, torch::kFloat64);
- parameters.at("1.linear.weight").data().flatten() =
+ parameters.at("2.weight").data().flatten() =
at::tensor({-0.0508, -0.3941, -0.2843}, torch::kFloat64);
- parameters.at("1.linear.bias").data() =
+ parameters.at("2.bias").data() =
at::tensor({-0.0711}, torch::kFloat64);
auto optimizer = OptimizerClass(parameters, options);
@@ -111,7 +116,10 @@
std::srand(0);
torch::manual_seed(0);
Sequential model(
- torch::SigmoidLinear(Linear(2, 8)), torch::SigmoidLinear(Linear(8, 1)));
+ Linear(2, 8),
+ Functional(at::sigmoid),
+ Linear(8, 1),
+ Functional(at::sigmoid));
SECTION("sgd") {
REQUIRE(test_optimizer_xor(
@@ -195,7 +203,7 @@
REQUIRE(!parameter->grad().defined());
}
- auto output = model->forward({torch::ones({5, 2})}).front();
+ auto output = model->forward(torch::ones({5, 2}));
auto loss = output.sum();
loss.backward();
diff --git a/test/cpp/api/rnn.cpp b/test/cpp/api/rnn.cpp
index 20a64f7..f45754a 100644
--- a/test/cpp/api/rnn.cpp
+++ b/test/cpp/api/rnn.cpp
@@ -22,9 +22,9 @@
auto T = x.size(0);
auto B = x.size(1);
x = x.view({T * B, 1});
- x = l1->forward({x})[0].view({T, B, nhid}).tanh_();
- x = rnn->forward({x})[0][T - 1];
- x = lo->forward({x})[0];
+ x = l1->forward(x).view({T, B, nhid}).tanh_();
+ x = rnn->forward(x).output[T - 1];
+ x = lo->forward(x);
return x;
};
@@ -61,26 +61,23 @@
return true;
};
-void check_lstm_sizes(std::vector<torch::Tensor> tup) {
+void check_lstm_sizes(RNNOutput output) {
// Expect the LSTM to have 64 outputs and 3 layers, with an input of batch
// 10 and 16 time steps (10 x 16 x n)
- auto out = tup[0];
- auto hids = tup[1];
+ REQUIRE(output.output.ndimension() == 3);
+ REQUIRE(output.output.size(0) == 10);
+ REQUIRE(output.output.size(1) == 16);
+ REQUIRE(output.output.size(2) == 64);
- REQUIRE(out.ndimension() == 3);
- REQUIRE(out.size(0) == 10);
- REQUIRE(out.size(1) == 16);
- REQUIRE(out.size(2) == 64);
-
- REQUIRE(hids.ndimension() == 4);
- REQUIRE(hids.size(0) == 2); // (hx, cx)
- REQUIRE(hids.size(1) == 3); // layers
- REQUIRE(hids.size(2) == 16); // Batchsize
- REQUIRE(hids.size(3) == 64); // 64 hidden dims
+ REQUIRE(output.state.ndimension() == 4);
+ REQUIRE(output.state.size(0) == 2); // (hx, cx)
+ REQUIRE(output.state.size(1) == 3); // layers
+ REQUIRE(output.state.size(2) == 16); // Batchsize
+ REQUIRE(output.state.size(3) == 64); // 64 hidden dims
// Something is in the hiddens
- REQUIRE(hids.norm().toCFloat() > 0);
+ REQUIRE(output.state.norm().toCFloat() > 0);
}
TEST_CASE("rnn") {
@@ -88,17 +85,17 @@
SECTION("sizes") {
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
auto x = torch::randn({10, 16, 128}, at::requires_grad());
- auto tup = model->forward({x});
+ auto output = model->forward(x);
auto y = x.mean();
y.backward();
- check_lstm_sizes(tup);
+ check_lstm_sizes(output);
- auto next = model->forward({x, tup[1]});
+ auto next = model->forward(x, output.state);
check_lstm_sizes(next);
- torch::Tensor diff = next[1] - tup[1];
+ torch::Tensor diff = next.state - output.state;
// Hiddens changed
REQUIRE(diff.data().abs().sum().toCFloat() > 1e-3);
@@ -122,13 +119,13 @@
p[i] = (size - i) / size;
}
- auto out = model->forward({x});
- REQUIRE(out[0].ndimension() == 3);
- REQUIRE(out[0].size(0) == 3);
- REQUIRE(out[0].size(1) == 4);
- REQUIRE(out[0].size(2) == 2);
+ auto out = model->forward(x);
+ REQUIRE(out.output.ndimension() == 3);
+ REQUIRE(out.output.size(0) == 3);
+ REQUIRE(out.output.size(1) == 4);
+ REQUIRE(out.output.size(2) == 2);
- auto flat = out[0].data().view(3 * 4 * 2);
+ auto flat = out.output.data().view(3 * 4 * 2);
float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239,
0.4183, 0.5147, 0.6822, 0.8064, 0.6726, 0.7968,
0.6620, 0.7860, 0.6501, 0.7741, 0.7889, 0.9003,
@@ -137,12 +134,12 @@
REQUIRE(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
}
- REQUIRE(out[1].ndimension() == 4); // (hx, cx) x layers x B x 2
- REQUIRE(out[1].size(0) == 2);
- REQUIRE(out[1].size(1) == 1);
- REQUIRE(out[1].size(2) == 4);
- REQUIRE(out[1].size(3) == 2);
- flat = out[1].data().view(16);
+ REQUIRE(out.state.ndimension() == 4); // (hx, cx) x layers x B x 2
+ REQUIRE(out.state.size(0) == 2);
+ REQUIRE(out.state.size(1) == 1);
+ REQUIRE(out.state.size(2) == 4);
+ REQUIRE(out.state.size(3) == 2);
+ flat = out.state.data().view(16);
float h_out[] = {0.7889,
0.9003,
0.7769,
@@ -192,17 +189,17 @@
LSTM model(LSTMOptions(128, 64).layers(3).dropout(0.2));
model->cuda();
auto x = torch::randn({10, 16, 128}, at::requires_grad().device(at::kCUDA));
- auto tup = model->forward({x});
+ auto output = model->forward(x);
auto y = x.mean();
y.backward();
- check_lstm_sizes(tup);
+ check_lstm_sizes(output);
- auto next = model->forward({x, tup[1]});
+ auto next = model->forward(x, output.state);
check_lstm_sizes(next);
- torch::Tensor diff = next[1] - tup[1];
+ torch::Tensor diff = next.state - output.state;
// Hiddens changed
REQUIRE(diff.data().abs().sum().toCFloat() > 1e-3);
diff --git a/test/cpp/api/sequential.cpp b/test/cpp/api/sequential.cpp
index a55eede..9c79a93 100644
--- a/test/cpp/api/sequential.cpp
+++ b/test/cpp/api/sequential.cpp
@@ -1,5 +1,6 @@
#include <catch.hpp>
+#include <torch/nn/modules.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/tensor.h>
@@ -177,12 +178,19 @@
Sequential sequential(Linear(10, 3), Linear(3, 5), Linear(5, 100));
auto x = torch::randn({1000, 10}, at::requires_grad());
- auto y =
- sequential
- .forward<std::vector<torch::Tensor>>(std::vector<torch::Tensor>{x})
- .front();
+ auto y = sequential.forward(x);
REQUIRE(y.ndimension() == 2);
REQUIRE(y.size(0) == 1000);
REQUIRE(y.size(1) == 100);
}
+
+ SECTION("can hold other important modules") {
+ Sequential sequential(
+ Linear(10, 3),
+ Conv2d(1, 2, 3),
+ Dropout(0.5),
+ BatchNorm(5),
+ Embedding(4, 10),
+ LSTM(4, 5));
+ }
}
diff --git a/test/cpp/api/serialization.cpp b/test/cpp/api/serialization.cpp
index f997a03..9aaa327 100644
--- a/test/cpp/api/serialization.cpp
+++ b/test/cpp/api/serialization.cpp
@@ -1,5 +1,6 @@
#include <catch.hpp>
+#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/optim/optimizer.h>
@@ -21,7 +22,10 @@
namespace {
std::shared_ptr<Sequential> xor_model() {
return std::make_shared<Sequential>(
- torch::SigmoidLinear(2, 8), torch::SigmoidLinear(8, 1));
+ Linear(2, 8),
+ Functional(at::sigmoid),
+ Linear(8, 1),
+ Functional(at::sigmoid));
}
} // namespace
@@ -244,7 +248,7 @@
auto step = [&](torch::optim::Optimizer& optimizer, Linear model) {
optimizer.zero_grad();
- auto y = model->forward({x})[0].sum();
+ auto y = model->forward(x).sum();
y.backward();
optimizer.step();
};
diff --git a/test/cpp/api/util.h b/test/cpp/api/util.h
index f204797..32b6221 100644
--- a/test/cpp/api/util.h
+++ b/test/cpp/api/util.h
@@ -1,8 +1,7 @@
-#include <torch/nn/module.h>
-#include <torch/nn/modules/linear.h>
+#pragma once
-#include <memory>
-#include <stdexcept>
+#include <torch/nn/cloneable.h>
+
#include <string>
#include <utility>
@@ -12,12 +11,6 @@
// for experimental implementations
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
public:
- virtual std::vector<Tensor> forward(std::vector<Tensor>) {
- throw std::runtime_error(
- "SimpleContainer has no forward, maybe you"
- " wanted to subclass and override this function?");
- }
-
void reset() override {}
template <typename ModuleHolder>
@@ -27,19 +20,4 @@
return Module::register_module(std::move(name), module_holder);
}
};
-
-struct SigmoidLinear : nn::Module {
- SigmoidLinear(int64_t in, int64_t out) : linear(nn::Linear(in, out)) {
- register_module("linear", linear);
- }
-
- explicit SigmoidLinear(nn::Linear linear_) : linear(std::move(linear_)) {
- register_module("linear", linear);
- }
- Tensor forward(Tensor input) {
- return linear->forward({input}).front().sigmoid();
- }
- nn::Linear linear;
-};
-
} // namespace torch
diff --git a/test/test_torch.py b/test/test_torch.py
index a1c84c8..edc3f85 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5526,7 +5526,7 @@
self.assertEqual(flat, src)
# out of bounds index
- with self.assertRaisesRegex(RuntimeError, 'dimension out of range'):
+ with self.assertRaisesRegex(RuntimeError, 'Dimension out of range'):
src.flatten(5, 10)
# invalid start and end
diff --git a/torch/csrc/api/include/torch/expanding_array.h b/torch/csrc/api/include/torch/expanding_array.h
index f03005a..24f6977 100644
--- a/torch/csrc/api/include/torch/expanding_array.h
+++ b/torch/csrc/api/include/torch/expanding_array.h
@@ -70,7 +70,7 @@
}
/// Returns an `ArrayRef` to the underlying `std::array`.
- operator at::ArrayRef<T>() {
+ operator at::ArrayRef<T>() const {
return values_;
}
diff --git a/torch/csrc/api/include/torch/nn/modules/batchnorm.h b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
index c8949b1..0b662bb 100644
--- a/torch/csrc/api/include/torch/nn/modules/batchnorm.h
+++ b/torch/csrc/api/include/torch/nn/modules/batchnorm.h
@@ -22,7 +22,10 @@
explicit BatchNormImpl(BatchNormOptions options);
void reset() override;
- std::vector<Tensor> forward(std::vector<Tensor>);
+
+ Tensor forward(Tensor input);
+ Tensor pure_forward(Tensor input, Tensor mean, Tensor variance);
+
const BatchNormOptions& options() const noexcept;
private:
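A small sketch of the forward/pure_forward split introduced above (a hypothetical snippet; the `.stateful(...)` setter is assumed to be generated by TORCH_ARG like the other option structs):

```cpp
#include <torch/nn/modules/batchnorm.h>
#include <torch/tensor.h>

torch::Tensor batchnorm_sketch() {
  using namespace torch::nn;
  auto input = torch::randn({2, 5, 4, 4});

  // Stateful module: tracks running_mean_/running_variance_ itself, so the
  // plain forward(input) overload suffices.
  BatchNorm stateful(BatchNormOptions(5).stateful(true));
  auto y = stateful->forward(input);

  // Non-stateful module: the caller owns the statistics and passes them in.
  BatchNorm stateless(BatchNormOptions(5).stateful(false));
  auto mean = torch::zeros({5});
  auto variance = torch::ones({5});
  auto z = stateless->pure_forward(input, mean, variance);

  return y + z;
}
```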
diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h
index fab17c2..f7a2dd1 100644
--- a/torch/csrc/api/include/torch/nn/modules/conv.h
+++ b/torch/csrc/api/include/torch/nn/modules/conv.h
@@ -43,20 +43,35 @@
ConvOptions<D> options_;
};
-#define CONV_D(D) \
- class Conv##D##dImpl : public ConvImpl<D, Conv##D##dImpl> { \
- public: \
- using ConvImpl<D, Conv##D##dImpl>::ConvImpl; \
- std::vector<Tensor> forward(std::vector<Tensor> input); \
- }; \
- using Conv##D##dOptions = ConvOptions<D>; \
- TORCH_MODULE(Conv##D##d)
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv1d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-CONV_D(1);
-CONV_D(2);
-CONV_D(3);
+class Conv1dImpl : public ConvImpl<1, Conv1dImpl> {
+ public:
+ using ConvImpl<1, Conv1dImpl>::ConvImpl;
+ Tensor forward(Tensor input);
+};
+using Conv1dOptions = ConvOptions<1>;
+TORCH_MODULE(Conv1d);
-#undef CONV_D
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+class Conv2dImpl : public ConvImpl<2, Conv2dImpl> {
+ public:
+ using ConvImpl<2, Conv2dImpl>::ConvImpl;
+ Tensor forward(Tensor input);
+};
+using Conv2dOptions = ConvOptions<2>;
+TORCH_MODULE(Conv2d);
+
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Conv3d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+class Conv3dImpl : public ConvImpl<3, Conv3dImpl> {
+ public:
+ using ConvImpl<3, Conv3dImpl>::ConvImpl;
+ Tensor forward(Tensor input);
+};
+using Conv3dOptions = ConvOptions<3>;
+TORCH_MODULE(Conv3d);
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/include/torch/nn/modules/dropout.h b/torch/csrc/api/include/torch/nn/modules/dropout.h
index 901ee3b..5a5762e 100644
--- a/torch/csrc/api/include/torch/nn/modules/dropout.h
+++ b/torch/csrc/api/include/torch/nn/modules/dropout.h
@@ -21,7 +21,7 @@
explicit DropoutImplBase(DropoutOptions options);
void reset() override;
- std::vector<Tensor> forward(std::vector<Tensor> input);
+ Tensor forward(Tensor input);
const DropoutOptions& options() const noexcept;
protected:
diff --git a/torch/csrc/api/include/torch/nn/modules/embedding.h b/torch/csrc/api/include/torch/nn/modules/embedding.h
index fcedb72..c345ca4 100644
--- a/torch/csrc/api/include/torch/nn/modules/embedding.h
+++ b/torch/csrc/api/include/torch/nn/modules/embedding.h
@@ -21,7 +21,7 @@
explicit EmbeddingImpl(EmbeddingOptions options);
void reset() override;
- std::vector<Tensor> forward(std::vector<Tensor>);
+ Tensor forward(Tensor);
const EmbeddingOptions& options() const noexcept;
private:
diff --git a/torch/csrc/api/include/torch/nn/modules/functional.h b/torch/csrc/api/include/torch/nn/modules/functional.h
index 21a42b9..e3a9d68 100644
--- a/torch/csrc/api/include/torch/nn/modules/functional.h
+++ b/torch/csrc/api/include/torch/nn/modules/functional.h
@@ -1,11 +1,11 @@
#pragma once
+#include <torch/csrc/utils/variadic.h>
#include <torch/nn/cloneable.h>
#include <torch/nn/pimpl.h>
#include <torch/tensor.h>
#include <functional>
-#include <vector>
namespace torch {
namespace nn {
@@ -14,13 +14,41 @@
// Sequential.
class FunctionalImpl : public torch::nn::Cloneable<FunctionalImpl> {
public:
- using Function = std::function<std::vector<Tensor>(std::vector<Tensor>)>;
+ using Function = std::function<Tensor(Tensor)>;
+
+ /// A small type that is used only in the constructor of `FunctionalImpl`,
+ /// that allows constructing it with a function with more than one argument,
+ /// and binding all but the first parameter to specific values. It is
+ /// necessary due to interaction with the `ModuleHolder` class, which expects
+ /// to construct a module with `Module({...})` when there is more than one
+ /// argument. It also deals with argument binding.
+ struct BoundFunction {
+ template <
+ typename AnyFunction,
+ typename... Args,
+ typename = torch::enable_if_t<(sizeof...(Args) > 0)>>
+ /* implicit */ BoundFunction(AnyFunction original_function, Args&&... args)
+ : function_(std::bind(
+ original_function,
+ /*input=*/std::placeholders::_1,
+ std::forward<Args>(args)...)) {
+ // std::bind is normally evil, but (1) gcc is broken w.r.t. handling
+ // parameter pack expansion in lambdas and (2) moving parameter packs into
+ // a lambda only works with C++14, so std::bind is the more move-aware
+ // solution here.
+ }
+
+ Function function_;
+ };
explicit FunctionalImpl(Function function);
- explicit FunctionalImpl(std::function<Tensor(Tensor)> function);
+ explicit FunctionalImpl(BoundFunction bound_function);
void reset() override;
- std::vector<Tensor> forward(std::vector<Tensor> input);
+ Tensor forward(Tensor input);
+
+ /// Calls forward(input).
+ Tensor operator()(Tensor input);
private:
Function function_;
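To show how the two FunctionalImpl constructors are reached from user code (an illustrative composition, mirroring the Sequential usage in the optim tests above):

```cpp
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/nn/modules/sequential.h>
#include <torch/tensor.h>

torch::Tensor functional_sketch() {
  using namespace torch::nn;
  Sequential model(
      Linear(10, 3),
      Functional(at::sigmoid),                            // plain Tensor(Tensor) function
      Linear(3, 1),
      Functional([](torch::Tensor x) { return x * 2; }),  // lambda -> Function constructor
      Functional(at::elu, /*alpha=*/1, /*scale=*/1));     // extra args -> BoundFunction
  return model.forward(torch::randn({4, 10}));
}
```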
diff --git a/torch/csrc/api/include/torch/nn/modules/linear.h b/torch/csrc/api/include/torch/nn/modules/linear.h
index 607fadd..612a2e4 100644
--- a/torch/csrc/api/include/torch/nn/modules/linear.h
+++ b/torch/csrc/api/include/torch/nn/modules/linear.h
@@ -17,12 +17,12 @@
TORCH_ARG(bool, with_bias) = true;
};
-class LinearImpl : public torch::nn::Cloneable<LinearImpl> {
+class LinearImpl : public Cloneable<LinearImpl> {
public:
explicit LinearImpl(LinearOptions options);
void reset() override;
- std::vector<Tensor> forward(std::vector<Tensor>);
+ Tensor forward(Tensor);
const LinearOptions& options() const noexcept;
private:
diff --git a/torch/csrc/api/include/torch/nn/modules/rnn.h b/torch/csrc/api/include/torch/nn/modules/rnn.h
index 2fa7aa2..7fc8a2a8 100644
--- a/torch/csrc/api/include/torch/nn/modules/rnn.h
+++ b/torch/csrc/api/include/torch/nn/modules/rnn.h
@@ -16,6 +16,12 @@
namespace torch {
namespace nn {
+
+struct RNNOutput {
+ Tensor output;
+ Tensor state;
+};
+
namespace detail {
struct RNNOptionsBase {
RNNOptionsBase(int64_t input_size, int64_t hidden_size);
@@ -40,7 +46,7 @@
int64_t number_of_gates = 1,
bool has_cell_state = false);
- std::vector<Tensor> forward(std::vector<Tensor>);
+ RNNOutput forward(Tensor input, Tensor state = {});
void reset() override;
@@ -55,12 +61,10 @@
void flatten_parameters_for_cudnn();
protected:
- virtual std::vector<Tensor> cell_forward(
- std::vector<Tensor>,
- int64_t layer) = 0;
+ virtual Tensor cell_forward(Tensor input, Tensor state, int64_t layer) = 0;
- std::vector<Tensor> CUDNN_forward(std::vector<Tensor>);
- std::vector<Tensor> autograd_forward(std::vector<Tensor>);
+ RNNOutput CUDNN_forward(Tensor input, Tensor state);
+ RNNOutput autograd_forward(Tensor input, Tensor state);
std::vector<Tensor> flat_weights() const;
@@ -114,7 +118,7 @@
const RNNOptions& options() const noexcept;
private:
- std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+ Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
RNNOptions options_;
std::function<Tensor(Tensor)> activation_function_;
@@ -133,7 +137,7 @@
const LSTMOptions& options() const noexcept;
private:
- std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+ Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
};
TORCH_MODULE(LSTM);
@@ -149,7 +153,7 @@
const GRUOptions& options() const noexcept;
private:
- std::vector<Tensor> cell_forward(std::vector<Tensor>, int64_t layer) override;
+ Tensor cell_forward(Tensor input, Tensor state, int64_t layer) override;
};
TORCH_MODULE(GRU);
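A brief sketch of the new RNN interface, with RNNOutput's named fields replacing the indexed std::vector<Tensor> return (dimensions follow the "sizes" test above):

```cpp
#include <torch/nn/modules/rnn.h>
#include <torch/tensor.h>

torch::Tensor rnn_sketch() {
  using namespace torch::nn;
  LSTM lstm(LSTMOptions(128, 64).layers(3));
  auto x = torch::randn({10, 16, 128});

  // forward() now returns named fields instead of a vector to index into.
  RNNOutput first = lstm->forward(x);
  // The hidden state can be threaded explicitly into the next call.
  RNNOutput next = lstm->forward(x, first.state);
  return next.output;
}
```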
diff --git a/torch/csrc/api/include/torch/nn/pimpl.h b/torch/csrc/api/include/torch/nn/pimpl.h
index 5f772c5..30bfcd4 100644
--- a/torch/csrc/api/include/torch/nn/pimpl.h
+++ b/torch/csrc/api/include/torch/nn/pimpl.h
@@ -1,6 +1,7 @@
#pragma once
#include <torch/csrc/utils/variadic.h>
+#include <torch/tensor.h>
#include <memory>
#include <type_traits>
@@ -71,6 +72,12 @@
return impl_.get();
}
+ /// Forwards to the call operator of the contained module.
+ template <typename... Args>
+ Tensor operator()(Args&&... args) {
+ return (*impl_)(std::forward<Args>(args)...);
+ }
+
/// Returns the underlying module.
const std::shared_ptr<Contained>& get() const {
AT_CHECK(!is_empty(), "Accessing empty ModuleHolder");
diff --git a/torch/csrc/api/src/nn/modules/batchnorm.cpp b/torch/csrc/api/src/nn/modules/batchnorm.cpp
index 267a09e..6de6efa 100644
--- a/torch/csrc/api/src/nn/modules/batchnorm.cpp
+++ b/torch/csrc/api/src/nn/modules/batchnorm.cpp
@@ -33,11 +33,13 @@
}
}
-std::vector<Tensor> BatchNormImpl::forward(std::vector<Tensor> inputs) {
- auto& input = inputs[0];
- auto& running_mean_ = (options_.stateful_ ? this->running_mean_ : inputs[1]);
- auto& running_variance_ =
- (options_.stateful_ ? this->running_variance_ : inputs[2]);
+Tensor BatchNormImpl::forward(Tensor input) {
+ return pure_forward(input, Tensor(), Tensor());
+}
+
+Tensor BatchNormImpl::pure_forward(Tensor input, Tensor mean, Tensor variance) {
+ auto& running_mean = options_.stateful_ ? running_mean_ : mean;
+ auto& running_variance = options_.stateful_ ? running_variance_ : variance;
if (is_training()) {
const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
@@ -46,18 +48,16 @@
"BatchNorm expected more than 1 value per channel when training!");
}
- auto output = at::batch_norm(
+ return at::batch_norm(
input,
weight_,
bias_,
- running_mean_,
- running_variance_,
+ running_mean,
+ running_variance,
is_training(),
options_.momentum_,
options_.eps_,
torch::cuda::cudnn_is_available());
-
- return std::vector<Tensor>({output});
}
const BatchNormOptions& BatchNormImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/conv.cpp b/torch/csrc/api/src/nn/modules/conv.cpp
index 1efd0ad..36ffb83 100644
--- a/torch/csrc/api/src/nn/modules/conv.cpp
+++ b/torch/csrc/api/src/nn/modules/conv.cpp
@@ -70,88 +70,87 @@
return options_;
}
-std::vector<Tensor> Conv1dImpl::forward(std::vector<Tensor> input) {
- AT_ASSERT(input.front().ndimension() == 3);
+Tensor Conv1dImpl::forward(Tensor input) {
+ AT_ASSERT(input.ndimension() == 3);
if (options_.transposed_) {
- return {at::conv_transpose1d(
- input.front(),
+ return at::conv_transpose1d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.output_padding_,
options_.groups_,
- options_.dilation_)};
+ options_.dilation_);
}
- return {at::conv1d(
- input.front(),
+ return at::conv1d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.dilation_,
- options_.groups_)};
+ options_.groups_);
}
-std::vector<Tensor> Conv2dImpl::forward(std::vector<Tensor> input) {
- AT_ASSERT(input.front().ndimension() == 4);
+Tensor Conv2dImpl::forward(Tensor input) {
+ AT_ASSERT(input.ndimension() == 4);
if (options_.transposed_) {
- return {at::conv_transpose2d(
- input.front(),
+ return at::conv_transpose2d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.output_padding_,
options_.groups_,
- options_.dilation_)};
+ options_.dilation_);
}
- return {at::conv2d(
- input.front(),
+ return at::conv2d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.dilation_,
- options_.groups_)};
+ options_.groups_);
}
-std::vector<Tensor> Conv3dImpl::forward(std::vector<Tensor> input) {
- AT_ASSERT(input.front().ndimension() == 5);
+Tensor Conv3dImpl::forward(Tensor input) {
+ AT_ASSERT(input.ndimension() == 5);
if (options_.transposed_) {
- return {at::conv_transpose3d(
- input.front(),
+ return at::conv_transpose3d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.output_padding_,
options_.groups_,
- options_.dilation_)};
+ options_.dilation_);
} else {
- return {at::conv3d(
- input.front(),
+ return at::conv3d(
+ input,
weight_,
bias_,
options_.stride_,
options_.padding_,
options_.dilation_,
- options_.groups_)};
+ options_.groups_);
}
}
-#define CONV_D(D) \
- template struct ConvOptions<D>; \
- template class ConvImpl<D, Conv##D##dImpl>
+template struct ConvOptions<1>;
+template class ConvImpl<1, Conv1dImpl>;
-CONV_D(1);
-CONV_D(2);
-CONV_D(3);
+template struct ConvOptions<2>;
+template class ConvImpl<2, Conv2dImpl>;
-#undef CONV_D
+template struct ConvOptions<3>;
+template class ConvImpl<3, Conv3dImpl>;
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/dropout.cpp b/torch/csrc/api/src/nn/modules/dropout.cpp
index ee5006e..2b544b2 100644
--- a/torch/csrc/api/src/nn/modules/dropout.cpp
+++ b/torch/csrc/api/src/nn/modules/dropout.cpp
@@ -21,19 +21,16 @@
void DropoutImplBase<Derived>::reset() {}
template <typename Derived>
-std::vector<Tensor> DropoutImplBase<Derived>::forward(
- std::vector<Tensor> input) {
+Tensor DropoutImplBase<Derived>::forward(Tensor input) {
if (options_.rate_ == 0 || !this->is_training()) {
return input;
}
- std::vector<Tensor> output;
- for (const auto& value : input) {
- const auto noise = (noise_mask(value).uniform_(0, 1) > options_.rate_)
- .toType(value.type().scalarType())
- .mul_(1.0f / (1.0f - options_.rate_));
- output.push_back(value * noise);
- }
- return output;
+
+ auto scale = 1.0f / (1.0f - options_.rate_);
+ auto boolean_mask = noise_mask(input).uniform_(0, 1) > options_.rate_;
+ auto noise = boolean_mask.to(input.dtype()).mul_(scale);
+
+ return input * noise;
}
template <typename Derived>
diff --git a/torch/csrc/api/src/nn/modules/embedding.cpp b/torch/csrc/api/src/nn/modules/embedding.cpp
index f69c8a5..93cb876 100644
--- a/torch/csrc/api/src/nn/modules/embedding.cpp
+++ b/torch/csrc/api/src/nn/modules/embedding.cpp
@@ -23,8 +23,8 @@
table_.data().normal_(0, 1);
}
-std::vector<Tensor> EmbeddingImpl::forward(std::vector<Tensor> input) {
- return {at::embedding(table_, /*indices=*/input[0])};
+Tensor EmbeddingImpl::forward(Tensor input) {
+ return at::embedding(table_, /*indices=*/input);
}
const EmbeddingOptions& EmbeddingImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/functional.cpp b/torch/csrc/api/src/nn/modules/functional.cpp
index 5245adf..d4a4cc3 100644
--- a/torch/csrc/api/src/nn/modules/functional.cpp
+++ b/torch/csrc/api/src/nn/modules/functional.cpp
@@ -4,22 +4,23 @@
#include <functional>
#include <utility>
-#include <vector>
namespace torch {
namespace nn {
-FunctionalImpl::FunctionalImpl(Function function)
+FunctionalImpl::FunctionalImpl(std::function<Tensor(Tensor)> function)
: function_(std::move(function)) {}
-FunctionalImpl::FunctionalImpl(std::function<Tensor(Tensor)> function)
- : function_([function](std::vector<Tensor> input) {
- return std::vector<Tensor>({function(input.front())});
- }) {}
+FunctionalImpl::FunctionalImpl(BoundFunction bound_function)
+ : function_(std::move(bound_function.function_)) {}
void FunctionalImpl::reset() {}
-std::vector<Tensor> FunctionalImpl::forward(std::vector<Tensor> input) {
+Tensor FunctionalImpl::forward(Tensor input) {
return function_(input);
}
+
+Tensor FunctionalImpl::operator()(Tensor input) {
+ return forward(input);
+}
} // namespace nn
} // namespace torch
diff --git a/torch/csrc/api/src/nn/modules/linear.cpp b/torch/csrc/api/src/nn/modules/linear.cpp
index d1178c8..065c842 100644
--- a/torch/csrc/api/src/nn/modules/linear.cpp
+++ b/torch/csrc/api/src/nn/modules/linear.cpp
@@ -26,19 +26,18 @@
}
}
-std::vector<Tensor> LinearImpl::forward(std::vector<Tensor> input) {
- auto x = input[0];
- if (x.ndimension() == 2 && options_.with_bias_) {
+Tensor LinearImpl::forward(Tensor input) {
+ if (input.ndimension() == 2 && options_.with_bias_) {
// Fused op is marginally faster
- AT_ASSERT(x.size(1) == weight_.size(1));
- return {at::addmm(bias_, x, weight_.t())};
+ AT_ASSERT(input.size(1) == weight_.size(1));
+ return {at::addmm(bias_, input, weight_.t())};
}
- auto output = x.matmul(weight_.t());
+ auto output = input.matmul(weight_.t());
if (options_.with_bias_) {
output += bias_;
}
- return {output};
+ return output;
}
const LinearOptions& LinearImpl::options() const noexcept {
diff --git a/torch/csrc/api/src/nn/modules/rnn.cpp b/torch/csrc/api/src/nn/modules/rnn.cpp
index 1a5f367..88a7595 100644
--- a/torch/csrc/api/src/nn/modules/rnn.cpp
+++ b/torch/csrc/api/src/nn/modules/rnn.cpp
@@ -96,14 +96,12 @@
}
template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::forward(std::vector<Tensor> inputs) {
- std::vector<Tensor> inp = {inputs[0],
- inputs.size() > 1 ? inputs[1] : Tensor()};
- if (cudnn_mode_.has_value() && at::cudnn_is_acceptable(inp[0]) &&
+RNNOutput RNNImplBase<Derived>::forward(Tensor input, Tensor state) {
+ if (cudnn_mode_.has_value() && at::cudnn_is_acceptable(input) &&
options_.dropout_ == 0) {
- return {CUDNN_forward(inp)};
+ return CUDNN_forward(input, state);
} else {
- return {autograd_forward(inp)};
+ return autograd_forward(input, state);
}
}
@@ -122,42 +120,39 @@
}
template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::autograd_forward(
- std::vector<Tensor> inputs) {
- auto inp = inputs[0];
-
- std::vector<Tensor> state;
- auto has_hidden = inputs[1].defined();
- auto layer_dimension = has_hidden ? inputs[1].ndimension() - 3 : -1;
+RNNOutput RNNImplBase<Derived>::autograd_forward(Tensor input, Tensor state) {
+ std::vector<at::Tensor> new_state;
+ auto has_hidden = state.defined();
+ auto layer_dimension = has_hidden ? state.ndimension() - 3 : -1;
for (int64_t layer = 0; layer < options_.layers_; layer++) {
- state.push_back(
- has_hidden ? inputs[1].select(layer_dimension, layer) : Tensor());
+ new_state.push_back(
+ has_hidden ? state.select(layer_dimension, layer) : Tensor());
}
auto output = torch::zeros(
- {inp.size(0), inp.size(1), options_.hidden_size_}, inp.options());
- for (int64_t t = 0; t < inp.size(0); t++) {
- auto x = inp.select(0, t);
+ {input.size(0), input.size(1), options_.hidden_size_}, input.options());
+ for (int64_t t = 0; t < input.size(0); t++) {
+ auto x = input.select(0, t);
for (int64_t i = 0; i < options_.layers_; i++) {
// cell_forward() returns a stacked tensor of one or more cell states.
- auto layer_output = cell_forward({x, state[i]}, i);
+ auto layer_output = cell_forward(x, new_state[i], i);
// If there are multiple cell states, keep all. If there is only one,
// the first dimension will be 1, so `.squeeze(0)` will unpack it.
- state[i] = layer_output[0].squeeze(0);
+ new_state[i] = layer_output.squeeze(0);
// x should always be the hidden cell state h, assumed to be the zero-th.
- x = layer_output[0][0];
+ x = layer_output[0];
output.select(0, t).copy_(x);
if (options_.dropout_ > 0 && i != options_.layers_ - 1) {
- x = dropout_module_->forward({x})[0];
+ x = dropout_module_->forward(x);
}
}
}
- auto state_output = at::stack(TensorListView(state));
+ auto state_output = at::stack(new_state);
if (has_cell_state_) {
state_output.transpose_(0, 1);
}
- return std::vector<Tensor>({output, state_output});
+ return {output, state_output};
}
template <typename Derived>
@@ -200,26 +195,26 @@
}
template <typename Derived>
-std::vector<Tensor> RNNImplBase<Derived>::CUDNN_forward(
- std::vector<Tensor> inputs) {
- auto x = inputs[0];
+RNNOutput RNNImplBase<Derived>::CUDNN_forward(Tensor input, Tensor state) {
Tensor hx, cx;
- if (inputs[1].defined()) {
+ if (state.defined()) {
if (has_cell_state_) {
- hx = inputs[1][0];
- cx = inputs[1][1];
+ hx = state[0];
+ cx = state[1];
} else {
- hx = inputs[1];
+ hx = state;
}
} else {
hx = torch::zeros(
- {options_.layers_, x.size(1), options_.hidden_size_}, x.options());
+ {options_.layers_, input.size(1), options_.hidden_size_},
+ input.options());
if (has_cell_state_) {
cx = torch::zeros(
- {options_.layers_, x.size(1), options_.hidden_size_}, x.options());
+ {options_.layers_, input.size(1), options_.hidden_size_},
+ input.options());
}
}
- auto dropout_state = torch::empty({}, x.type());
+ auto dropout_state = torch::empty({}, input.type());
std::vector<void*> weight_data_ptrs;
for (auto& p : this->parameters()) {
@@ -235,7 +230,7 @@
// tup = std::tuple of output, hy, cy, reserve, new_weight_buf
auto tup = _cudnn_rnn(
- x,
+ input,
TensorListView(flat_weights()),
/*weight_stride=*/options_.with_bias_ ? 4 : 2,
flat_weights_,
@@ -261,7 +256,7 @@
}
Tensor output = std::get<0>(tup);
- return std::vector<Tensor>({output, hidden_output});
+ return {output, hidden_output};
}
template <typename Derived>
@@ -323,18 +318,15 @@
}
}
-std::vector<Tensor> RNNImpl::cell_forward(
- std::vector<Tensor> inputs,
- int64_t layer) {
- auto x = inputs[0];
- auto hx = inputs[1].defined()
- ? inputs[1]
- : torch::zeros({x.size(0), options_.hidden_size_}, x.options());
+Tensor RNNImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+ auto hx = state.defined()
+ ? state
+ : torch::zeros({input.size(0), options_.hidden_size_}, input.options());
- auto h = linear(x, ihw_[layer], ihb_[layer]) +
+ auto h = linear(input, ihw_[layer], ihb_[layer]) +
linear(hx, hhw_[layer], hhb_[layer]);
- return {at::stack(TensorListView(activation_function_(h)))};
+ return at::stack(activation_function_(h));
}
const RNNOptions& RNNImpl::options() const noexcept {
@@ -350,17 +342,15 @@
/*number_of_gates=*/4,
/*has_cell_state=*/true) {}
-std::vector<Tensor> LSTMImpl::cell_forward(
- std::vector<Tensor> inputs,
- int64_t layer) {
- auto x = inputs[0];
- auto hid = inputs[1].defined()
- ? inputs[1]
- : torch::zeros({2, x.size(0), options_.hidden_size_}, x.options());
+Tensor LSTMImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+ auto hid = state.defined()
+ ? state
+ : torch::zeros(
+ {2, input.size(0), options_.hidden_size_}, input.options());
auto hx = hid[0];
auto cx = hid[1];
- auto gates = linear(x, ihw_[layer], ihb_[layer]) +
+ auto gates = linear(input, ihw_[layer], ihb_[layer]) +
linear(hx, hhw_[layer], hhb_[layer]);
auto chunked = gates.chunk(4, 1);
@@ -372,7 +362,7 @@
auto cy = (forget_gate * cx) + (in_gate * cell_gate);
auto hy = out_gate * cy.tanh();
- return {at::stack(TensorListView({hy, cy}), 0)};
+ return at::stack(TensorListView{hy, cy}, 0);
}
const LSTMOptions& LSTMImpl::options() const noexcept {
@@ -387,16 +377,13 @@
/*cudnn_mode=*/CuDNNMode::GRU,
/*number_of_gates=*/3) {}
-std::vector<Tensor> GRUImpl::cell_forward(
- std::vector<Tensor> inputs,
- int64_t layer) {
- auto x = inputs[0];
- auto hx = inputs[1].defined()
- ? inputs[1]
- : torch::zeros({x.size(0), options_.hidden_size_}, x.options());
+Tensor GRUImpl::cell_forward(Tensor input, Tensor state, int64_t layer) {
+ auto hx = state.defined()
+ ? state
+ : torch::zeros({input.size(0), options_.hidden_size_}, input.options());
- auto gi = linear(x, ihw_[layer], ihb_[layer]);
- auto gh = linear(x, hhw_[layer], hhb_[layer]);
+ auto gi = linear(input, ihw_[layer], ihb_[layer]);
+ auto gh = linear(input, hhw_[layer], hhb_[layer]);
auto gic = gi.chunk(3, 1);
auto ghc = gh.chunk(3, 1);
@@ -405,7 +392,7 @@
auto new_gate = (gic[2] + reset_gate * ghc[2]).tanh_();
auto hy = new_gate + input_gate * (hx - new_gate);
- return {at::stack(TensorListView(hy))};
+ return at::stack(TensorListView(hy));
}
const GRUOptions& GRUImpl::options() const noexcept {