Merge autogradpp into PyTorch (#7074)
* Dump autogradpp into PyTorch
* Fixed up CMake for autogradpp/C++ API
* Made cereal a submodule
* Change search location of autogradpp's mnist directory
* Add test_api to CI
* Download MNIST from the internet instead of storing in repo
* Fix warnings
diff --git a/.gitignore b/.gitignore
index 1f4b2c7..6d6e1a4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,6 +55,7 @@
test/data/legacy_serialized.pt
test/data/linear.pt
.mypy_cache
+test/cpp/api/mnist
# IPython notebook checkpoints
.ipynb_checkpoints
diff --git a/.gitmodules b/.gitmodules
index da502fa..c73454e 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -77,3 +77,6 @@
[submodule "third_party/ideep"]
path = third_party/ideep
url = https://github.com/intel/ideep.git
+[submodule "third_party/cereal"]
+ path = third_party/cereal
+ url = https://github.com/USCiLab/cereal
diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh
index 8bf47d0..d0b6ab5 100755
--- a/.jenkins/pytorch/test.sh
+++ b/.jenkins/pytorch/test.sh
@@ -55,4 +55,6 @@
else
"$CPP_BUILD"/libtorch/bin/test_jit "[cpu]"
fi
+ python tools/download_mnist.py -d test/cpp/api/mnist
+ "$CPP_BUILD"/libtorch/bin/test_api
fi
diff --git a/test/cpp/api/container_t.cpp b/test/cpp/api/container_t.cpp
new file mode 100644
index 0000000..323d37d
--- /dev/null
+++ b/test/cpp/api/container_t.cpp
@@ -0,0 +1,258 @@
+#include "test.h"
+
+AUTOGRAD_CONTAINER_CLASS(TestModel) {
+ public:
+ void initialize_containers() override {
+ add(Linear(10, 3).make(), "l1");
+ add(Linear(3, 5).make(), "l2");
+ add(Linear(5, 100).make(), "l3");
+ }
+
+ variable_list forward(variable_list input) override { return input; };
+};
+
+AUTOGRAD_CONTAINER_CLASS(NestedModel) {
+ public:
+ void initialize_containers() override {
+ add(Linear(5, 20).make(), "l1");
+ add(TestModel().make(), "test");
+ }
+
+ void initialize_parameters() override {
+ add(Var(DefaultTensor(at::kFloat).tensor({3, 2, 21}), false), "param");
+ }
+
+ variable_list forward(variable_list input) override { return input; };
+};
+
+CASE("containers/conv2d/even") {
+ auto model = Conv2d(3, 2, 3).stride(2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 4);
+ EXPECT(s.ndimension() == 0);
+ for (auto i = 0; i < 4; i++) {
+ EXPECT(y.size(i) == 2);
+ }
+
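+ // The weight gradient has shape (out_channels, in_channels, kH, kW) = (2, 3, 3, 3).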
+ EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3);
+};
+
+CASE("containers/conv2d/uneven") {
+ auto model = Conv2d(3, 2, IntVec({3, 2})).stride(2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 4}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 4);
+ EXPECT(s.ndimension() == 0);
+ for (auto i = 0; i < 4; i++) {
+ EXPECT(y.size(i) == 2);
+ }
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 2);
+};
+
+CASE("containers/conv1d/even") {
+ auto model = Conv1d(3, 2, 3).stride(2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 3);
+ EXPECT(s.ndimension() == 0);
+ for (auto i = 0; i < 3; i++) {
+ EXPECT(y.size(i) == 2);
+ }
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3);
+};
+
+CASE("containers/conv3d/even") {
+ auto model = Conv3d(3, 2, 3).stride(2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({2, 3, 5, 5, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 5);
+ EXPECT(s.ndimension() == 0);
+ for (auto i = 0; i < 5; i++) {
+ EXPECT(y.size(i) == 2);
+ }
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 3 * 2 * 3 * 3 * 3);
+};
+
+CASE("containers/linear/basic1") {
+ auto model = Linear(5, 2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 2);
+ EXPECT(s.ndimension() == 0);
+ EXPECT(y.size(0) == 10);
+ EXPECT(y.size(1) == 2);
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
+};
+
+CASE("containers/linear/sequential") {
+ auto model = ContainerList()
+ .append(Linear(10, 3).make())
+ .append(Linear(3, 5).make())
+ .append(Linear(5, 100).make())
+ .make();
+
+ auto x = Var(at::CPU(at::kFloat).randn({1000, 10}));
+ for (auto layer : *model) {
+ x = layer->forward({x})[0];
+ x = x.clamp_min(0); // relu
+ }
+
+ backward(x);
+ EXPECT(x.ndimension() == 2);
+ EXPECT(x.size(0) == 1000);
+ EXPECT(x.size(1) == 100);
+ EXPECT(x.data().min().toCFloat() == 0);
+};
+
+CASE("containers/linear/simple") {
+ auto model = SimpleContainer().make();
+ auto l1 = model->add(Linear(10, 3).make(), "l1");
+ auto l2 = model->add(Linear(3, 5).make(), "l2");
+ auto l3 = model->add(Linear(5, 100).make(), "l3");
+
+ auto x = Var(at::CPU(at::kFloat).randn({1000, 10}));
+ x = l1->forward({x})[0].clamp_min(0);
+ x = l2->forward({x})[0].clamp_min(0);
+ x = l3->forward({x})[0].clamp_min(0);
+
+ backward(x);
+ EXPECT(x.ndimension() == 2);
+ EXPECT(x.size(0) == 1000);
+ EXPECT(x.size(1) == 100);
+ EXPECT(x.data().min().toCFloat() == 0);
+};
+
+CASE("containers/clone") {
+ auto model = TestModel().make();
+
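+ // A clone must copy parameter values without sharing storage:
+ // mutating the original afterwards should not affect the clone.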
+ auto model2 = model->clone();
+ auto m1param = model->parameters();
+ auto m2param = model2->parameters();
+ for (auto& param : m1param) {
+ EXPECT(param.second.allclose(m2param[param.first]));
+ param.second.data().mul_(2);
+ }
+ for (auto& param : m1param) {
+ EXPECT(!param.second.allclose(m2param[param.first]));
+ }
+};
+
+CASE("containers/embedding/basic") {
+ int dict_size = 10;
+ auto model = Embedding(dict_size, 2).make();
+ // Indices (the input) receive no gradients; only the embedding weights do.
+ auto x = Var(at::CPU(at::kLong).tensor({10}).fill_(dict_size - 1), false);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 2);
+ EXPECT(s.ndimension() == 0);
+ EXPECT(y.size(0) == 10);
+ EXPECT(y.size(1) == 2);
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 2 * dict_size);
+};
+
+CASE("containers/embedding/list") {
+ auto model = Embedding(6, 4).make();
+ auto x = Var(at::CPU(at::kLong).tensor({2, 3}).fill_(5), false);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 3);
+ EXPECT(y.size(0) == 2);
+ EXPECT(y.size(1) == 3);
+ EXPECT(y.size(2) == 4);
+};
+
+CASE("containers/cuda/1") {
+ CUDA_GUARD;
+ auto model = Linear(5, 2).make();
+ model->cuda();
+ auto x = Var(at::CUDA(at::kFloat).randn({10, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 2);
+ EXPECT(s.ndimension() == 0);
+ EXPECT(y.size(0) == 10);
+ EXPECT(y.size(1) == 2);
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
+};
+
+CASE("containers/cuda/2") {
+ CUDA_GUARD;
+ auto model = Linear(5, 2).make();
+ model->cuda();
+ model->cpu();
+ auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(y.ndimension() == 2);
+ EXPECT(s.ndimension() == 0);
+ EXPECT(y.size(0) == 10);
+ EXPECT(y.size(1) == 2);
+
+ EXPECT(model->parameters()["weight"].grad().numel() == 2 * 5);
+};
+
+CASE("containers/dropout/1") {
+ auto dropout = Dropout(0.5).make();
+ Variable x = Var(at::CPU(at::kFloat).ones(100));
+ Variable y = dropout->forward({x})[0];
+
+ backward(y);
+ EXPECT(y.ndimension() == 1);
+ EXPECT(y.size(0) == 100);
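+ // Inverted dropout scales surviving units by 1/(1 - p) = 2, so the expected sum stays near 100.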
+ EXPECT(y.sum().toCFloat() < 130); // Probably
+ EXPECT(y.sum().toCFloat() > 70); // Probably
+
+ dropout->eval();
+ y = dropout->forward({x})[0];
+ EXPECT(y.data().sum().toCFloat() == 100);
+};
+
+CASE("containers/param") {
+ auto model = NestedModel().make();
+ EXPECT(model->param("param").size(0) == 3);
+ EXPECT(model->param("param").size(1) == 2);
+ EXPECT(model->param("param").size(2) == 21);
+ EXPECT(model->param("l1.bias").size(0) == 20);
+ EXPECT(model->param("l1.weight").size(0) == 20);
+ EXPECT(model->param("l1.weight").size(1) == 5);
+ EXPECT(model->param("test.l1.bias").size(0) == 3);
+ EXPECT(model->param("test.l1.weight").size(0) == 3);
+ EXPECT(model->param("test.l1.weight").size(1) == 10);
+ EXPECT(model->param("test.l2.bias").size(0) == 5);
+ EXPECT(model->param("test.l2.weight").size(0) == 5);
+ EXPECT(model->param("test.l2.weight").size(1) == 3);
+ EXPECT(model->param("test.l3.bias").size(0) == 100);
+ EXPECT(model->param("test.l3.weight").size(0) == 100);
+ EXPECT(model->param("test.l3.weight").size(1) == 5);
+}
diff --git a/test/cpp/api/integration_t.cpp b/test/cpp/api/integration_t.cpp
new file mode 100644
index 0000000..83d3ad5
--- /dev/null
+++ b/test/cpp/api/integration_t.cpp
@@ -0,0 +1,355 @@
+#include "test.h"
+
+class CartPole {
+ // Translated from openai/gym's cartpole.py
+ public:
+ double gravity = 9.8;
+ double masscart = 1.0;
+ double masspole = 0.1;
+ double total_mass = (masspole + masscart);
+ double length = 0.5; // actually half the pole's length;
+ double polemass_length = (masspole * length);
+ double force_mag = 10.0;
+ double tau = 0.02; // seconds between state updates;
+
+ // Angle at which to fail the episode
+ double theta_threshold_radians = 12 * 2 * M_PI / 360;
+ double x_threshold = 2.4;
+ int steps_beyond_done = -1;
+
+ at::Tensor state;
+ double reward;
+ bool done;
+ int step_ = 0;
+
+ at::Tensor getState() {
+ return state;
+ }
+
+ double getReward() {
+ return reward;
+ }
+
+ bool isDone() {
+ return done;
+ }
+
+ void reset() {
+ state = at::CPU(at::kFloat).tensor({4}).uniform_(-0.05, 0.05);
+ steps_beyond_done = -1;
+ step_ = 0;
+ }
+
+ CartPole() {
+ reset();
+ }
+
+ void step(int action) {
+ auto x = state[0].toCFloat();
+ auto x_dot = state[1].toCFloat();
+ auto theta = state[2].toCFloat();
+ auto theta_dot = state[3].toCFloat();
+
+ auto force = (action == 1) ? force_mag : -force_mag;
+ auto costheta = std::cos(theta);
+ auto sintheta = std::sin(theta);
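+ // Standard cart-pole equations of motion, as in gym's cartpole.py.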
+ auto temp = (force + polemass_length * theta_dot * theta_dot * sintheta) / total_mass;
+ auto thetaacc = (gravity * sintheta - costheta * temp) / (length * (4.0 / 3.0 - masspole * costheta * costheta / total_mass));
+ auto xacc = temp - polemass_length * thetaacc * costheta / total_mass;
+
+ x = x + tau * x_dot;
+ x_dot = x_dot + tau * xacc;
+ theta = theta + tau * theta_dot;
+ theta_dot = theta_dot + tau * thetaacc;
+ state[0] = x;
+ state[1] = x_dot;
+ state[2] = theta;
+ state[3] = theta_dot;
+ done = x < -x_threshold
+ || x > x_threshold
+ || theta < -theta_threshold_radians
+ || theta > theta_threshold_radians
+ || step_ > 200;
+
+ if (!done) {
+ reward = 1.0;
+ } else if (steps_beyond_done == -1) {
+ // Pole just fell!
+ steps_beyond_done = 0;
+ reward = 0;
+ } else {
+ if (steps_beyond_done == 0) {
+ assert(false); // Can't do this
+ }
+ }
+ step_++;
+ }
+};
+
+template <typename M, typename F, typename O>
+bool test_mnist(uint32_t batch_size, uint32_t num_epochs, bool useGPU,
+ M&& model, F&& forward_op, O&& optim) {
+ std::cout << "Training MNIST for " << num_epochs << " epochs, rest your eyes for a bit!\n";
+ struct MNIST_Reader
+ {
+ FILE *fp_;
+
+ MNIST_Reader(const char *path) {
+ fp_ = fopen(path, "rb");
+ if (!fp_) throw std::runtime_error("failed to open file");
+ }
+
+ ~MNIST_Reader() { if (fp_) fclose(fp_); }
+
+ int32_t read_int() {
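+ // IDX files store integers big-endian; assemble the four bytes MSB-first.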
+ uint8_t buf[4];
+ if (fread(buf, sizeof(buf), 1, fp_) != 1) throw std::runtime_error("failed to read an integer");
+ return int32_t(buf[0] << 24 | buf[1] << 16 | buf[2] << 8 | buf[3]);
+ }
+
+ uint8_t read_byte() {
+ uint8_t i;
+ if (fread(&i, sizeof(i), 1, fp_) != 1) throw std::runtime_error("failed to read a byte");
+ return i;
+ }
+ };
+
+ auto readData = [&](std::string fn) {
+ MNIST_Reader rd(fn.c_str());
+
+ /* int image_magic = */ rd.read_int();
+ int image_count = rd.read_int();
+ int image_rows = rd.read_int();
+ int image_cols = rd.read_int();
+
+ auto data = at::CPU(at::kFloat).tensor({image_count, 1, image_rows, image_cols});
+ auto a_data = data.accessor<float, 4>();
+
+ for (int c = 0; c < image_count; c++) {
+ for (int i = 0; i < image_rows; i++) {
+ for (int j = 0; j < image_cols; j++) {
+ a_data[c][0][i][j] = float(rd.read_byte()) / 255;
+ }
+ }
+ }
+
+ return data.toBackend(useGPU ? at::kCUDA : at::kCPU);
+ };
+
+ auto readLabels = [&](std::string fn) {
+ MNIST_Reader rd(fn.c_str());
+ /* int label_magic = */ rd.read_int();
+ int label_count = rd.read_int();
+
+ auto data = at::CPU(at::kLong).tensor({label_count});
+ auto a_data = data.accessor<int64_t, 1>();
+
+ for (int i = 0; i < label_count; ++i) {
+ a_data[i] = long(rd.read_byte());
+ }
+ return data.toBackend(useGPU ? at::kCUDA : at::kCPU);
+ };
+
+ auto trdata = readData("test/cpp/api/mnist/train-images-idx3-ubyte");
+ auto trlabel = readLabels("test/cpp/api/mnist/train-labels-idx1-ubyte");
+ auto tedata = readData("test/cpp/api/mnist/t10k-images-idx3-ubyte");
+ auto telabel = readLabels("test/cpp/api/mnist/t10k-labels-idx1-ubyte");
+
+ if (useGPU) {
+ model->cuda();
+ }
+
+ for (auto epoch = 0U; epoch < num_epochs; epoch++) {
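+ // Reshuffle the training indices each epoch so batches differ across epochs.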
+ auto shuffled_inds = std::vector<int>(trdata.size(0));
+ for (int i = 0; i < trdata.size(0); i++) {
+ shuffled_inds[i] = i;
+ }
+ std::random_shuffle(shuffled_inds.begin(), shuffled_inds.end());
+
+ auto inp = (useGPU ? at::CUDA : at::CPU)(at::kFloat).tensor({batch_size, 1, trdata.size(2), trdata.size(3)});
+ auto lab = (useGPU ? at::CUDA : at::CPU)(at::kLong).tensor({batch_size});
+ for (auto p = 0U; p < shuffled_inds.size() - batch_size; p++) {
+ inp[p % batch_size] = trdata[shuffled_inds[p]];
+ lab[p % batch_size] = trlabel[shuffled_inds[p]];
+
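+ // Step the optimizer only once a full batch has been accumulated.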
+ if (p % batch_size != batch_size - 1) continue;
+ Variable x = forward_op(Var(inp));
+ Variable y = Var(lab, false);
+ Variable loss = at::nll_loss(x, y);
+
+ optim->zero_grad();
+ backward(loss);
+ optim->step();
+ }
+ }
+
+ no_grad_guard guard;
+ auto result = std::get<1>(forward_op(Var(tedata, false)).max(1));
+ Variable correct = (result == Var(telabel)).toType(at::kFloat);
+ std::cout << "Num correct: " << correct.data().sum().toCFloat()
+ << " out of " << telabel.size(0) << std::endl;
+ return correct.data().sum().toCFloat() > telabel.size(0) * 0.8;
+}
+
+CASE("integration/RL/cartpole") {
+ std::cout << "Training episodic policy gradient with a critic for up to 3000"
+ " episodes, rest your eyes for a bit!\n";
+ auto model = SimpleContainer().make();
+ auto linear = model->add(Linear(4, 128).make(), "linear");
+ auto policyHead = model->add(Linear(128, 2).make(), "policy");
+ auto valueHead = model->add(Linear(128, 1).make(), "action");
+ auto optim = Adam(model, 1e-3).make();
+
+ std::vector<Variable> saved_log_probs;
+ std::vector<Variable> saved_values;
+ std::vector<float> rewards;
+
+ auto forward = [&](variable_list inp) {
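+ // One shared hidden layer feeds two heads: softmax action probabilities and a scalar state value.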
+ auto x = linear->forward(inp)[0].clamp_min(0);
+ Variable actions = policyHead->forward({x})[0];
+ Variable value = valueHead->forward({x})[0];
+ return std::make_tuple(at::softmax(actions, -1), value);
+ };
+
+ auto selectAction = [&](at::Tensor state) {
+ // Only works on a single state right now; switch the index to a gather for batches.
+ auto out = forward({Var(state, false)});
+ auto probs = Variable(std::get<0>(out));
+ auto value = Variable(std::get<1>(out));
+ auto action = probs.data().multinomial(1)[0].toCInt();
+ // Compute the log prob of a multinomial distribution.
+ // This should probably be implemented in autogradpp itself...
+ auto p = probs / probs.sum(-1, true);
+ auto log_prob = p[action].log();
+ saved_log_probs.push_back(log_prob);
+ saved_values.push_back(value);
+ return action;
+ };
+
+ auto finishEpisode = [&]() {
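+ // Turn per-step rewards into discounted returns (gamma = 0.99), back to front.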
+ auto R = 0.;
+ for (int i = rewards.size() - 1; i >= 0; i--) {
+ R = rewards[i] + 0.99 * R;
+ rewards[i] = R;
+ }
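+ // Normalize the returns to zero mean and unit variance to reduce gradient variance.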
+ auto r_t = at::CPU(at::kFloat).tensorFromBlob(rewards.data(), {static_cast<int64_t>(rewards.size())});
+ r_t = (r_t - r_t.mean()) / (r_t.std() + 1e-5);
+
+ std::vector<at::Tensor> policy_loss;
+ std::vector<at::Tensor> value_loss;
+ for (auto i = 0U; i < saved_log_probs.size(); i++) {
+ auto r = rewards[i] - saved_values[i].toCFloat();
+ policy_loss.push_back(- r * saved_log_probs[i]);
+ value_loss.push_back(at::smooth_l1_loss(saved_values[i], Var(at::CPU(at::kFloat).scalarTensor(at::Scalar(rewards[i])), false)));
+ }
+ auto loss = at::stack(policy_loss).sum() + at::stack(value_loss).sum();
+
+ optim->zero_grad();
+ backward(loss);
+ optim->step();
+
+ rewards.clear();
+ saved_log_probs.clear();
+ saved_values.clear();
+ };
+
+ auto env = CartPole();
+ double running_reward = 10.0;
+ for (auto episode = 0; ; episode++) {
+ env.reset();
+ auto state = env.getState();
+ int t = 0;
+ for ( ; t < 10000; t++) {
+ auto action = selectAction(state);
+ env.step(action);
+ state = env.getState();
+ auto reward = env.getReward();
+ auto done = env.isDone();
+
+ rewards.push_back(reward);
+ if (done) break;
+ }
+
+ running_reward = running_reward * 0.99 + t * 0.01;
+ finishEpisode();
+ /*
+ if (episode % 10 == 0) {
+ printf("Episode %i\tLast length: %5d\tAverage length: %.2f\n",
+ episode, t, running_reward);
+ }
+ */
+ if (running_reward > 150) break;
+ EXPECT(episode < 3000);
+ }
+}
+
+CASE("integration/mnist") { // ~ will make it run last :D
+ CUDA_GUARD;
+ auto model = SimpleContainer().make();
+ auto conv1 = model->add(Conv2d(1, 10, 5).make(), "conv1");
+ auto conv2 = model->add(Conv2d(10, 20, 5).make(), "conv2");
+ auto drop = Dropout(0.3).make();
+ auto drop2d = Dropout2d(0.3).make();
+ auto linear1 = model->add(Linear(320, 50).make(), "linear1");
+ auto linear2 = model->add(Linear(50, 10).make(), "linear2");
+
+ auto forward = [&](Variable x) {
+ x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2})).clamp_min(0);
+ x = conv2->forward({x})[0];
+ x = drop2d->forward({x})[0];
+ x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
+
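+ // After two 5x5 convs and two 2x2 max-pools, a 28x28 input shrinks to 20 x 4 x 4 = 320 features.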
+ x = x.view({-1, 320});
+ x = linear1->forward({x})[0].clamp_min(0);
+ x = drop->forward({x})[0];
+ x = linear2->forward({x})[0];
+ x = at::log_softmax(x, 1);
+ return x;
+ };
+
+ auto optim = SGD(model, 1e-2).momentum(0.5).make();
+
+ EXPECT(test_mnist(
+ 32, // batch_size
+ 3, // num_epochs
+ true, // useGPU
+ model, forward, optim));
+};
+
+CASE("integration/mnist_batchnorm") { // ~ will make it run last :D
+ CUDA_GUARD;
+ auto model = SimpleContainer().make();
+ auto conv1 = model->add(Conv2d(1, 10, 5).make(), "conv1");
+ auto batchnorm2d = model->add(
+ BatchNorm(10).stateful().make(),
+ "batchnorm2d");
+ auto conv2 = model->add(Conv2d(10, 20, 5).make(), "conv2");
+ auto linear1 = model->add(Linear(320, 50).make(), "linear1");
+ auto batchnorm1 = model->add(
+ BatchNorm(50).stateful().make(),
+ "batchnorm1");
+ auto linear2 = model->add(Linear(50, 10).make(), "linear2");
+
+ auto forward = [&](Variable x) {
+ x = std::get<0>(at::max_pool2d(conv1->forward({x})[0], {2, 2})).clamp_min(0);
+ x = batchnorm2d->forward({x})[0];
+ x = conv2->forward({x})[0];
+ x = std::get<0>(at::max_pool2d(x, {2, 2})).clamp_min(0);
+
+ x = x.view({-1, 320});
+ x = linear1->forward({x})[0].clamp_min(0);
+ x = batchnorm1->forward({x})[0];
+ x = linear2->forward({x})[0];
+ x = at::log_softmax(x, 1);
+ return x;
+ };
+
+ auto optim = SGD(model, 1e-2).momentum(0.5).make();
+
+ EXPECT(test_mnist(
+ 32, // batch_size
+ 3, // num_epochs
+ true, // useGPU
+ model, forward, optim));
+};
diff --git a/test/cpp/api/lest.hpp b/test/cpp/api/lest.hpp
new file mode 100644
index 0000000..525be38
--- /dev/null
+++ b/test/cpp/api/lest.hpp
@@ -0,0 +1,1308 @@
+// Copyright 2013-2018 by Martin Moene
+//
+// lest is based on ideas by Kevlin Henney, see video at
+// http://skillsmatter.com/podcast/agile-testing/kevlin-henney-rethinking-unit-testing-in-c-plus-plus
+//
+// Distributed under the Boost Software License, Version 1.0. (See accompanying
+// file LICENSE.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
+
+#ifndef LEST_LEST_HPP_INCLUDED
+#define LEST_LEST_HPP_INCLUDED
+
+#include <algorithm>
+#include <chrono>
+#include <functional>
+#include <iomanip>
+#include <iostream>
+#include <iterator>
+#include <limits>
+#include <random>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <set>
+#include <tuple>
+#include <typeinfo>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+#include <cctype>
+#include <cmath>
+#include <cstddef>
+#include <cstdlib>
+
+#ifdef __clang__
+# pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments"
+# pragma clang diagnostic ignored "-Woverloaded-shift-op-parentheses"
+# pragma clang diagnostic ignored "-Wunused-comparison"
+# pragma clang diagnostic ignored "-Wunused-value"
+#elif defined __GNUC__
+# pragma GCC diagnostic ignored "-Wunused-value"
+#endif
+
+#define lest_VERSION "1.32.0"
+
+#ifndef lest_FEATURE_AUTO_REGISTER
+# define lest_FEATURE_AUTO_REGISTER 0
+#endif
+
+#ifndef lest_FEATURE_COLOURISE
+# define lest_FEATURE_COLOURISE 0
+#endif
+
+#ifndef lest_FEATURE_LITERAL_SUFFIX
+# define lest_FEATURE_LITERAL_SUFFIX 0
+#endif
+
+#ifndef lest_FEATURE_REGEX_SEARCH
+# define lest_FEATURE_REGEX_SEARCH 0
+#endif
+
+#ifndef lest_FEATURE_TIME_PRECISION
+#define lest_FEATURE_TIME_PRECISION 0
+#endif
+
+#ifndef lest_FEATURE_WSTRING
+#define lest_FEATURE_WSTRING 1
+#endif
+
+#ifdef lest_FEATURE_RTTI
+# define lest__cpp_rtti lest_FEATURE_RTTI
+#elif defined(__cpp_rtti)
+# define lest__cpp_rtti __cpp_rtti
+#elif defined(__GXX_RTTI) || defined (_CPPRTTI)
+# define lest__cpp_rtti 1
+#else
+# define lest__cpp_rtti 0
+#endif
+
+#if lest_FEATURE_REGEX_SEARCH
+# include <regex>
+#endif
+
+#if ! defined( lest_NO_SHORT_MACRO_NAMES ) && ! defined( lest_NO_SHORT_ASSERTION_NAMES )
+# define MODULE lest_MODULE
+
+# if ! lest_FEATURE_AUTO_REGISTER
+# define CASE lest_CASE
+# define SCENARIO lest_SCENARIO
+# endif
+
+# define SETUP lest_SETUP
+# define SECTION lest_SECTION
+
+# define EXPECT lest_EXPECT
+# define EXPECT_NOT lest_EXPECT_NOT
+# define EXPECT_NO_THROW lest_EXPECT_NO_THROW
+# define EXPECT_THROWS lest_EXPECT_THROWS
+# define EXPECT_THROWS_AS lest_EXPECT_THROWS_AS
+
+# define GIVEN lest_GIVEN
+# define WHEN lest_WHEN
+# define THEN lest_THEN
+# define AND_WHEN lest_AND_WHEN
+# define AND_THEN lest_AND_THEN
+#endif
+
+#if lest_FEATURE_AUTO_REGISTER
+#define lest_SCENARIO( specification, sketch ) lest_CASE( specification, lest::text("Scenario: ") + sketch )
+#else
+#define lest_SCENARIO( sketch ) lest_CASE( lest::text("Scenario: ") + sketch )
+#endif
+#define lest_GIVEN( context ) lest_SETUP( lest::text( "Given: ") + context )
+#define lest_WHEN( story ) lest_SECTION( lest::text( " When: ") + story )
+#define lest_THEN( story ) lest_SECTION( lest::text( " Then: ") + story )
+#define lest_AND_WHEN( story ) lest_SECTION( lest::text( " And: ") + story )
+#define lest_AND_THEN( story ) lest_SECTION( lest::text( " And: ") + story )
+
+#if lest_FEATURE_AUTO_REGISTER
+
+# define lest_CASE( specification, proposition ) \
+ static void lest_FUNCTION( lest::env & ); \
+ namespace { lest::add_test lest_REGISTRAR( specification, lest::test( proposition, lest_FUNCTION ) ); } \
+ static void lest_FUNCTION( lest::env & lest_env )
+
+#else // lest_FEATURE_AUTO_REGISTER
+
+# define lest_CASE( proposition, ... ) \
+ proposition, [__VA_ARGS__]( lest::env & lest_env )
+
+# define lest_MODULE( specification, module ) \
+ namespace { lest::add_module _( specification, module ); }
+
+#endif //lest_FEATURE_AUTO_REGISTER
+
+#define lest_SETUP( context ) \
+ for ( int lest__section = 0, lest__count = 1; lest__section < lest__count; lest__count -= 0==lest__section++ )
+
+#define lest_SECTION( proposition ) \
+ static int lest_UNIQUE( id ) = 0; \
+ if ( lest::guard( lest_UNIQUE( id ), lest__section, lest__count ) ) \
+ for ( int lest__section = 0, lest__count = 1; lest__section < lest__count; lest__count -= 0==lest__section++ )
+
+#define lest_EXPECT( expr ) \
+ do { \
+ try \
+ { \
+ if ( lest::result score = lest_DECOMPOSE( expr ) ) \
+ throw lest::failure{ lest_LOCATION, #expr, score.decomposition }; \
+ else if ( lest_env.pass ) \
+ lest::report( lest_env.os, lest::passing{ lest_LOCATION, #expr, score.decomposition }, lest_env.testing ); \
+ } \
+ catch(...) \
+ { \
+ lest::inform( lest_LOCATION, #expr ); \
+ } \
+ } while ( lest::is_false() )
+
+#define lest_EXPECT_NOT( expr ) \
+ do { \
+ try \
+ { \
+ if ( lest::result score = lest_DECOMPOSE( expr ) ) \
+ { \
+ if ( lest_env.pass ) \
+ lest::report( lest_env.os, lest::passing{ lest_LOCATION, lest::not_expr( #expr ), lest::not_expr( score.decomposition ) }, lest_env.testing ); \
+ } \
+ else \
+ throw lest::failure{ lest_LOCATION, lest::not_expr( #expr ), lest::not_expr( score.decomposition ) }; \
+ } \
+ catch(...) \
+ { \
+ lest::inform( lest_LOCATION, lest::not_expr( #expr ) ); \
+ } \
+ } while ( lest::is_false() )
+
+#define lest_EXPECT_NO_THROW( expr ) \
+ do \
+ { \
+ try \
+ { \
+ expr; \
+ } \
+ catch (...) \
+ { \
+ lest::inform( lest_LOCATION, #expr ); \
+ } \
+ if ( lest_env.pass ) \
+ lest::report( lest_env.os, lest::got_none( lest_LOCATION, #expr ), lest_env.testing ); \
+ } while ( lest::is_false() )
+
+#define lest_EXPECT_THROWS( expr ) \
+ do \
+ { \
+ try \
+ { \
+ expr; \
+ } \
+ catch (...) \
+ { \
+ if ( lest_env.pass ) \
+ lest::report( lest_env.os, lest::got{ lest_LOCATION, #expr }, lest_env.testing ); \
+ break; \
+ } \
+ throw lest::expected{ lest_LOCATION, #expr }; \
+ } \
+ while ( lest::is_false() )
+
+#define lest_EXPECT_THROWS_AS( expr, excpt ) \
+ do \
+ { \
+ try \
+ { \
+ expr; \
+ } \
+ catch ( excpt & ) \
+ { \
+ if ( lest_env.pass ) \
+ lest::report( lest_env.os, lest::got{ lest_LOCATION, #expr, lest::of_type( #excpt ) }, lest_env.testing ); \
+ break; \
+ } \
+ catch (...) {} \
+ throw lest::expected{ lest_LOCATION, #expr, lest::of_type( #excpt ) }; \
+ } \
+ while ( lest::is_false() )
+
+#define lest_UNIQUE( name ) lest_UNIQUE2( name, __LINE__ )
+#define lest_UNIQUE2( name, line ) lest_UNIQUE3( name, line )
+#define lest_UNIQUE3( name, line ) name ## line
+
+#define lest_DECOMPOSE( expr ) ( lest::expression_decomposer() << expr )
+
+#define lest_FUNCTION lest_UNIQUE(__lest_function__ )
+#define lest_REGISTRAR lest_UNIQUE(__lest_registrar__ )
+
+#define lest_LOCATION lest::location{__FILE__, __LINE__}
+
+namespace lest {
+
+using text = std::string;
+using texts = std::vector<text>;
+
+struct env;
+
+struct test
+{
+ text name;
+ std::function<void( env & )> behaviour;
+
+#if lest_FEATURE_AUTO_REGISTER
+ test( text name, std::function<void( env & )> behaviour )
+ : name( name ), behaviour( behaviour ) {}
+#endif
+};
+
+using tests = std::vector<test>;
+
+#if lest_FEATURE_AUTO_REGISTER
+
+struct add_test
+{
+ add_test( tests & specification, test const & test_case )
+ {
+ specification.push_back( test_case );
+ }
+};
+
+#else
+
+struct add_module
+{
+ template <std::size_t N>
+ add_module( tests & specification, test const (&module)[N] )
+ {
+ specification.insert( specification.end(), std::begin( module ), std::end( module ) );
+ }
+};
+
+#endif
+
+struct result
+{
+ const bool passed;
+ const text decomposition;
+
+ template< typename T >
+ result( T const & passed, text decomposition )
+ : passed( !!passed ), decomposition( decomposition ) {}
+
+ explicit operator bool() { return ! passed; }
+};
+
+struct location
+{
+ const text file;
+ const int line;
+
+ location( text file, int line )
+ : file( file ), line( line ) {}
+};
+
+struct comment
+{
+ const text info;
+
+ comment( text info ) : info( info ) {}
+ explicit operator bool() { return ! info.empty(); }
+};
+
+struct message : std::runtime_error
+{
+ const text kind;
+ const location where;
+ const comment note;
+
+ ~message() throw() {} // GCC 4.6
+
+ message( text kind, location where, text expr, text note = "" )
+ : std::runtime_error( expr ), kind( kind ), where( where ), note( note ) {}
+};
+
+struct failure : message
+{
+ failure( location where, text expr, text decomposition )
+ : message{ "failed", where, expr + " for " + decomposition } {}
+};
+
+struct success : message
+{
+// using message::message; // VC is lagging here
+
+ success( text kind, location where, text expr, text note = "" )
+ : message( kind, where, expr, note ) {}
+};
+
+struct passing : success
+{
+ passing( location where, text expr, text decomposition )
+ : success( "passed", where, expr + " for " + decomposition ) {}
+};
+
+struct got_none : success
+{
+ got_none( location where, text expr )
+ : success( "passed: got no exception", where, expr ) {}
+};
+
+struct got : success
+{
+ got( location where, text expr )
+ : success( "passed: got exception", where, expr ) {}
+
+ got( location where, text expr, text excpt )
+ : success( "passed: got exception " + excpt, where, expr ) {}
+};
+
+struct expected : message
+{
+ expected( location where, text expr, text excpt = "" )
+ : message{ "failed: didn't get exception", where, expr, excpt } {}
+};
+
+struct unexpected : message
+{
+ unexpected( location where, text expr, text note = "" )
+ : message{ "failed: got unexpected exception", where, expr, note } {}
+};
+
+struct guard
+{
+ int & id;
+ int const & section;
+
+ guard( int & id, int const & section, int & count )
+ : id( id ), section( section )
+ {
+ if ( section == 0 )
+ id = count++ - 1;
+ }
+ operator bool() { return id == section; }
+};
+
+class approx
+{
+public:
+ explicit approx ( double magnitude )
+ : epsilon_ { std::numeric_limits<float>::epsilon() * 100 }
+ , scale_ { 1.0 }
+ , magnitude_{ magnitude } {}
+
+ approx( approx const & other ) = default;
+
+ static approx custom() { return approx( 0 ); }
+
+ approx operator()( double magnitude )
+ {
+ approx approx ( magnitude );
+ approx.epsilon( epsilon_ );
+ approx.scale ( scale_ );
+ return approx;
+ }
+
+ double magnitude() const { return magnitude_; }
+
+ approx & epsilon( double epsilon ) { epsilon_ = epsilon; return *this; }
+ approx & scale ( double scale ) { scale_ = scale; return *this; }
+
+ friend bool operator == ( double lhs, approx const & rhs )
+ {
+ // Thanks to Richard Harris for his help refining this formula.
+ return std::abs( lhs - rhs.magnitude_ ) < rhs.epsilon_ * ( rhs.scale_ + (std::min)( std::abs( lhs ), std::abs( rhs.magnitude_ ) ) );
+ }
+
+ friend bool operator == ( approx const & lhs, double rhs ) { return operator==( rhs, lhs ); }
+ friend bool operator != ( double lhs, approx const & rhs ) { return !operator==( lhs, rhs ); }
+ friend bool operator != ( approx const & lhs, double rhs ) { return !operator==( rhs, lhs ); }
+
+ friend bool operator <= ( double lhs, approx const & rhs ) { return lhs < rhs.magnitude_ || lhs == rhs; }
+ friend bool operator <= ( approx const & lhs, double rhs ) { return lhs.magnitude_ < rhs || lhs == rhs; }
+ friend bool operator >= ( double lhs, approx const & rhs ) { return lhs > rhs.magnitude_ || lhs == rhs; }
+ friend bool operator >= ( approx const & lhs, double rhs ) { return lhs.magnitude_ > rhs || lhs == rhs; }
+
+private:
+ double epsilon_;
+ double scale_;
+ double magnitude_;
+};
+
+inline bool is_false( ) { return false; }
+inline bool is_true ( bool flag ) { return flag; }
+
+inline text not_expr( text message )
+{
+ return "! ( " + message + " )";
+}
+
+inline text with_message( text message )
+{
+ return "with message \"" + message + "\"";
+}
+
+inline text of_type( text type )
+{
+ return "of type " + type;
+}
+
+inline void inform( location where, text expr )
+{
+ try
+ {
+ throw;
+ }
+ catch( message const & )
+ {
+ throw;
+ }
+ catch( std::exception const & e )
+ {
+ throw unexpected{ where, expr, with_message( e.what() ) }; \
+ }
+ catch(...)
+ {
+ throw unexpected{ where, expr, "of unknown type" }; \
+ }
+}
+
+// Expression decomposition:
+
+template<typename T>
+auto make_value_string( T const & value ) -> std::string;
+
+template<typename T>
+auto make_memory_string( T const & item ) -> std::string;
+
+#if lest_FEATURE_LITERAL_SUFFIX
+inline char const * sfx( char const * text ) { return text; }
+#else
+inline char const * sfx( char const * ) { return ""; }
+#endif
+
+inline std::string to_string( std::nullptr_t ) { return "nullptr"; }
+inline std::string to_string( std::string const & text ) { return "\"" + text + "\"" ; }
+#if lest_FEATURE_WSTRING
+inline std::string to_string( std::wstring const & text ) ;
+#endif
+
+inline std::string to_string( char const * const text ) { return text ? to_string( std::string ( text ) ) : "{null string}"; }
+inline std::string to_string( char * const text ) { return text ? to_string( std::string ( text ) ) : "{null string}"; }
+#if lest_FEATURE_WSTRING
+inline std::string to_string( wchar_t const * const text ) { return text ? to_string( std::wstring( text ) ) : "{null string}"; }
+inline std::string to_string( wchar_t * const text ) { return text ? to_string( std::wstring( text ) ) : "{null string}"; }
+#endif
+
+inline std::string to_string( char text ) { return "\'" + std::string( 1, text ) + "\'" ; }
+inline std::string to_string( signed char text ) { return "\'" + std::string( 1, text ) + "\'" ; }
+inline std::string to_string( unsigned char text ) { return "\'" + std::string( 1, text ) + "\'" ; }
+
+inline std::string to_string( bool flag ) { return flag ? "true" : "false"; }
+
+inline std::string to_string( signed short value ) { return make_value_string( value ) ; }
+inline std::string to_string( unsigned short value ) { return make_value_string( value ) + sfx("u" ); }
+inline std::string to_string( signed int value ) { return make_value_string( value ) ; }
+inline std::string to_string( unsigned int value ) { return make_value_string( value ) + sfx("u" ); }
+inline std::string to_string( signed long value ) { return make_value_string( value ) + sfx("l" ); }
+inline std::string to_string( unsigned long value ) { return make_value_string( value ) + sfx("ul" ); }
+inline std::string to_string( signed long long value ) { return make_value_string( value ) + sfx("ll" ); }
+inline std::string to_string( unsigned long long value ) { return make_value_string( value ) + sfx("ull"); }
+inline std::string to_string( double value ) { return make_value_string( value ) ; }
+inline std::string to_string( float value ) { return make_value_string( value ) + sfx("f" ); }
+
+template<typename T>
+struct is_streamable
+{
+ template<typename U>
+ static auto test( int ) -> decltype( std::declval<std::ostream &>() << std::declval<U>(), std::true_type() );
+
+ template<typename>
+ static auto test( ... ) -> std::false_type;
+
+#ifdef _MSC_VER
+ enum { value = std::is_same< decltype( test<T>(0) ), std::true_type >::value };
+#else
+ static constexpr bool value = std::is_same< decltype( test<T>(0) ), std::true_type >::value;
+#endif
+};
+
+template<typename T>
+struct is_container
+{
+ template<typename U>
+ static auto test( int ) -> decltype( std::declval<U>().begin() == std::declval<U>().end(), std::true_type() );
+
+ template<typename>
+ static auto test( ... ) -> std::false_type;
+
+#ifdef _MSC_VER
+ enum { value = std::is_same< decltype( test<T>(0) ), std::true_type >::value };
+#else
+ static constexpr bool value = std::is_same< decltype( test<T>(0) ), std::true_type >::value;
+#endif
+};
+
+template <typename T, typename R>
+using ForEnum = typename std::enable_if< std::is_enum<T>::value, R>::type;
+
+template <typename T, typename R>
+using ForNonEnum = typename std::enable_if< ! std::is_enum<T>::value, R>::type;
+
+template <typename T, typename R>
+using ForStreamable = typename std::enable_if< is_streamable<T>::value, R>::type;
+
+template <typename T, typename R>
+using ForNonStreamable = typename std::enable_if< ! is_streamable<T>::value, R>::type;
+
+template <typename T, typename R>
+using ForContainer = typename std::enable_if< is_container<T>::value, R>::type;
+
+template <typename T, typename R>
+using ForNonContainer = typename std::enable_if< ! is_container<T>::value, R>::type;
+
+template<typename T>
+auto make_enum_string( T const & ) -> ForNonEnum<T, std::string>
+{
+#if lest__cpp_rtti
+ return text("[type: ") + typeid(T).name() + "]";
+#else
+ return text("[type: (no RTTI)]");
+#endif
+}
+
+template<typename T>
+auto make_enum_string( T const & item ) -> ForEnum<T, std::string>
+{
+ return to_string( static_cast<typename std::underlying_type<T>::type>( item ) );
+}
+
+template<typename T>
+auto make_string( T const & item ) -> ForNonStreamable<T, std::string>
+{
+ return make_enum_string( item );
+}
+
+template<typename T>
+auto make_string( T const & item ) -> ForStreamable<T, std::string>
+{
+ std::ostringstream os; os << item; return os.str();
+}
+
+template<typename T>
+auto make_string( T * p )-> std::string
+{
+ if ( p ) return make_memory_string( p );
+ else return "NULL";
+}
+
+template<typename C, typename R>
+auto make_string( R C::* p ) -> std::string
+{
+ if ( p ) return make_memory_string( p );
+ else return "NULL";
+}
+
+template<typename T1, typename T2>
+auto make_string( std::pair<T1,T2> const & pair ) -> std::string
+{
+ std::ostringstream oss;
+ oss << "{ " << to_string( pair.first ) << ", " << to_string( pair.second ) << " }";
+ return oss.str();
+}
+
+template<typename TU, std::size_t N>
+struct make_tuple_string
+{
+ static std::string make( TU const & tuple )
+ {
+ std::ostringstream os;
+ os << to_string( std::get<N - 1>( tuple ) ) << ( N < std::tuple_size<TU>::value ? ", ": " ");
+ return make_tuple_string<TU, N - 1>::make( tuple ) + os.str();
+ }
+};
+
+template<typename TU>
+struct make_tuple_string<TU, 0>
+{
+ static std::string make( TU const & ) { return ""; }
+};
+
+template<typename ...TS>
+auto make_string( std::tuple<TS...> const & tuple ) -> std::string
+{
+ return "{ " + make_tuple_string<std::tuple<TS...>, sizeof...(TS)>::make( tuple ) + "}";
+}
+
+template<typename T>
+auto to_string( T const & item ) -> ForNonContainer<T, std::string>
+{
+ return make_string( item );
+}
+
+template<typename C>
+auto to_string( C const & cont ) -> ForContainer<C, std::string>
+{
+ std::ostringstream os;
+ os << "{ ";
+ for ( auto & x : cont )
+ {
+ os << to_string( x ) << ", ";
+ }
+ os << "}";
+ return os.str();
+}
+
+#if lest_FEATURE_WSTRING
+inline
+auto to_string( std::wstring const & text ) -> std::string
+{
+ std::string result; result.reserve( text.size() );
+
+ for( auto & chr : text )
+ {
+ result += chr <= 0xff ? static_cast<char>( chr ) : '?';
+ }
+ return to_string( result );
+}
+#endif
+
+template<typename T>
+auto make_value_string( T const & value ) -> std::string
+{
+ std::ostringstream os; os << value; return os.str();
+}
+
+inline
+auto make_memory_string( void const * item, std::size_t size ) -> std::string
+{
+ // reverse order for little endian architectures:
+
+ auto is_little_endian = []
+ {
+ union U { int i = 1; char c[ sizeof(int) ]; };
+
+ return 1 != U{}.c[ sizeof(int) - 1 ];
+ };
+
+ int i = 0, end = static_cast<int>( size ), inc = 1;
+
+ if ( is_little_endian() ) { i = end - 1; end = inc = -1; }
+
+ unsigned char const * bytes = static_cast<unsigned char const *>( item );
+
+ std::ostringstream os;
+ os << "0x" << std::setfill( '0' ) << std::hex;
+ for ( ; i != end; i += inc )
+ {
+ os << std::setw(2) << static_cast<unsigned>( bytes[i] ) << " ";
+ }
+ return os.str();
+}
+
+template<typename T>
+auto make_memory_string( T const & item ) -> std::string
+{
+ return make_memory_string( &item, sizeof item );
+}
+
+inline
+auto to_string( approx const & appr ) -> std::string
+{
+ return to_string( appr.magnitude() );
+}
+
+template <typename L, typename R>
+auto to_string( L const & lhs, std::string op, R const & rhs ) -> std::string
+{
+ std::ostringstream os; os << to_string( lhs ) << " " << op << " " << to_string( rhs ); return os.str();
+}
+
+template <typename L>
+struct expression_lhs
+{
+ const L lhs;
+
+ expression_lhs( L lhs ) : lhs( lhs ) {}
+
+ operator result() { return result{ !!lhs, to_string( lhs ) }; }
+
+ template <typename R> result operator==( R const & rhs ) { return result{ lhs == rhs, to_string( lhs, "==", rhs ) }; }
+ template <typename R> result operator!=( R const & rhs ) { return result{ lhs != rhs, to_string( lhs, "!=", rhs ) }; }
+ template <typename R> result operator< ( R const & rhs ) { return result{ lhs < rhs, to_string( lhs, "<" , rhs ) }; }
+ template <typename R> result operator<=( R const & rhs ) { return result{ lhs <= rhs, to_string( lhs, "<=", rhs ) }; }
+ template <typename R> result operator> ( R const & rhs ) { return result{ lhs > rhs, to_string( lhs, ">" , rhs ) }; }
+ template <typename R> result operator>=( R const & rhs ) { return result{ lhs >= rhs, to_string( lhs, ">=", rhs ) }; }
+};
+
+struct expression_decomposer
+{
+ template <typename L>
+ expression_lhs<L const &> operator<< ( L const & operand )
+ {
+ return expression_lhs<L const &>( operand );
+ }
+};
+
+// Reporter:
+
+#if lest_FEATURE_COLOURISE
+
+inline text red ( text words ) { return "\033[1;31m" + words + "\033[0m"; }
+inline text green( text words ) { return "\033[1;32m" + words + "\033[0m"; }
+inline text gray ( text words ) { return "\033[1;30m" + words + "\033[0m"; }
+
+inline bool starts_with( text words, text with )
+{
+ return 0 == words.find( with );
+}
+
+inline text replace( text words, text from, text to )
+{
+ size_t pos = words.find( from );
+ return pos == std::string::npos ? words : words.replace( pos, from.length(), to );
+}
+
+inline text colour( text words )
+{
+ if ( starts_with( words, "failed" ) ) return replace( words, "failed", red ( "failed" ) );
+ else if ( starts_with( words, "passed" ) ) return replace( words, "passed", green( "passed" ) );
+
+ return replace( words, "for", gray( "for" ) );
+}
+
+inline bool is_cout( std::ostream & os ) { return &os == &std::cout; }
+
+struct colourise
+{
+ const text words;
+
+ colourise( text words )
+ : words( words ) {}
+
+ // only colourise for std::cout, not for a stringstream as used in tests:
+
+ std::ostream & operator()( std::ostream & os ) const
+ {
+ return is_cout( os ) ? os << colour( words ) : os << words;
+ }
+};
+
+inline std::ostream & operator<<( std::ostream & os, colourise words ) { return words( os ); }
+#else
+inline text colourise( text words ) { return words; }
+#endif
+
+inline text pluralise( text word, int n )
+{
+ return n == 1 ? word : word + "s";
+}
+
+inline std::ostream & operator<<( std::ostream & os, comment note )
+{
+ return os << (note ? " " + note.info : "" );
+}
+
+inline std::ostream & operator<<( std::ostream & os, location where )
+{
+#ifdef __GNUG__
+ return os << where.file << ":" << where.line;
+#else
+ return os << where.file << "(" << where.line << ")";
+#endif
+}
+
+inline void report( std::ostream & os, message const & e, text test )
+{
+ os << e.where << ": " << colourise( e.kind ) << e.note << ": " << test << ": " << colourise( e.what() ) << std::endl;
+}
+
+// Test runner:
+
+#if lest_FEATURE_REGEX_SEARCH
+ inline bool search( text re, text line )
+ {
+ return std::regex_search( line, std::regex( re ) );
+ }
+#else
+ inline bool search( text part, text line )
+ {
+ auto case_insensitive_equal = []( char a, char b )
+ {
+ return tolower( a ) == tolower( b );
+ };
+
+ return std::search(
+ line.begin(), line.end(),
+ part.begin(), part.end(), case_insensitive_equal ) != line.end();
+ }
+#endif
+
+inline bool match( texts whats, text line )
+{
+ for ( auto & what : whats )
+ {
+ if ( search( what, line ) )
+ return true;
+ }
+ return false;
+}
+
+inline bool select( text name, texts include )
+{
+ auto none = []( texts args ) { return args.size() == 0; };
+
+#if lest_FEATURE_REGEX_SEARCH
+ auto hidden = []( text name ){ return match( { "\\[\\..*", "\\[hide\\]" }, name ); };
+#else
+ auto hidden = []( text name ){ return match( { "[.", "[hide]" }, name ); };
+#endif
+
+ if ( none( include ) )
+ {
+ return ! hidden( name );
+ }
+
+ bool any = false;
+ for ( auto pos = include.rbegin(); pos != include.rend(); ++pos )
+ {
+ auto & part = *pos;
+
+ if ( part == "@" || part == "*" )
+ return true;
+
+ if ( search( part, name ) )
+ return true;
+
+ if ( '!' == part[0] )
+ {
+ any = true;
+ if ( search( part.substr(1), name ) )
+ return false;
+ }
+ else
+ {
+ any = false;
+ }
+ }
+ return any && ! hidden( name );
+}
+
+inline int indefinite( int repeat ) { return repeat == -1; }
+
+using seed_t = unsigned long;
+
+struct options
+{
+ bool help = false;
+ bool abort = false;
+ bool count = false;
+ bool list = false;
+ bool tags = false;
+ bool time = false;
+ bool pass = false;
+ bool lexical = false;
+ bool random = false;
+ bool version = false;
+ int repeat = 1;
+ seed_t seed = 0;
+};
+
+struct env
+{
+ std::ostream & os;
+ bool pass;
+ text testing;
+
+ env( std::ostream & os, bool pass )
+ : os( os ), pass( pass ), testing() {}
+
+ env & operator()( text test )
+ {
+ testing = test; return *this;
+ }
+};
+
+struct action
+{
+ std::ostream & os;
+
+ action( std::ostream & os ) : os( os ) {}
+
+ action( action const & ) = delete;
+ void operator=( action const & ) = delete;
+
+ operator int() { return 0; }
+ bool abort() { return false; }
+ action & operator()( test ) { return *this; }
+};
+
+struct print : action
+{
+ print( std::ostream & os ) : action( os ) {}
+
+ print & operator()( test testing )
+ {
+ os << testing.name << "\n"; return *this;
+ }
+};
+
+inline texts tags( text name, texts result = {} )
+{
+ auto none = std::string::npos;
+ auto lb = name.find_first_of( "[" );
+ auto rb = name.find_first_of( "]" );
+
+ if ( lb == none || rb == none )
+ return result;
+
+ result.emplace_back( name.substr( lb, rb - lb + 1 ) );
+
+ return tags( name.substr( rb + 1 ), result );
+}
+
+struct ptags : action
+{
+ std::set<text> result;
+
+ ptags( std::ostream & os ) : action( os ), result() {}
+
+ ptags & operator()( test testing )
+ {
+ for ( auto & tag : tags( testing.name ) )
+ result.insert( tag );
+
+ return *this;
+ }
+
+ ~ptags()
+ {
+ std::copy( result.begin(), result.end(), std::ostream_iterator<text>( os, "\n" ) );
+ }
+};
+
+struct count : action
+{
+ int n = 0;
+
+ count( std::ostream & os ) : action( os ) {}
+
+ count & operator()( test ) { ++n; return *this; }
+
+ ~count()
+ {
+ os << n << " selected " << pluralise("test", n) << "\n";
+ }
+};
+
+struct timer
+{
+ using time = std::chrono::high_resolution_clock;
+
+ time::time_point start = time::now();
+
+ double elapsed_seconds() const
+ {
+ return 1e-6 * std::chrono::duration_cast< std::chrono::microseconds >( time::now() - start ).count();
+ }
+};
+
+struct times : action
+{
+ env output;
+ options option;
+ int selected = 0;
+ int failures = 0;
+
+ timer total;
+
+ times( std::ostream & os, options option )
+ : action( os ), output( os, option.pass ), option( option ), total()
+ {
+ os << std::setfill(' ') << std::fixed << std::setprecision( lest_FEATURE_TIME_PRECISION );
+ }
+
+ operator int() { return failures; }
+
+ bool abort() { return option.abort && failures > 0; }
+
+ times & operator()( test testing )
+ {
+ timer t;
+
+ try
+ {
+ testing.behaviour( output( testing.name ) );
+ }
+ catch( message const & )
+ {
+ ++failures;
+ }
+
+ os << std::setw(3) << ( 1000 * t.elapsed_seconds() ) << " ms: " << testing.name << "\n";
+
+ return *this;
+ }
+
+ ~times()
+ {
+ os << "Elapsed time: " << std::setprecision(1) << total.elapsed_seconds() << " s\n";
+ }
+};
+
+struct confirm : action
+{
+ env output;
+ options option;
+ int selected = 0;
+ int failures = 0;
+
+ confirm( std::ostream & os, options option )
+ : action( os ), output( os, option.pass ), option( option ) {}
+
+ operator int() { return failures; }
+
+ bool abort() { return option.abort && failures > 0; }
+
+ confirm & operator()( test testing )
+ {
+ try
+ {
+ ++selected; testing.behaviour( output( testing.name ) );
+ }
+ catch( message const & e )
+ {
+ ++failures; report( os, e, testing.name );
+ }
+ return *this;
+ }
+
+ ~confirm()
+ {
+ if ( failures > 0 )
+ {
+ os << failures << " out of " << selected << " selected " << pluralise("test", selected) << " " << colourise( "failed.\n" );
+ }
+ else if ( option.pass )
+ {
+ os << "All " << selected << " selected " << pluralise("test", selected) << " " << colourise( "passed.\n" );
+ }
+ }
+};
+
+template<typename Action>
+bool abort( Action & perform )
+{
+ return perform.abort();
+}
+
+template< typename Action >
+Action && for_test( tests specification, texts in, Action && perform, int n = 1 )
+{
+ for ( int i = 0; indefinite( n ) || i < n; ++i )
+ {
+ for ( auto & testing : specification )
+ {
+ if ( select( testing.name, in ) )
+ if ( abort( perform( testing ) ) )
+ return std::move( perform );
+ }
+ }
+ return std::move( perform );
+}
+
+inline void sort( tests & specification )
+{
+ auto test_less = []( test const & a, test const & b ) { return a.name < b.name; };
+ std::sort( specification.begin(), specification.end(), test_less );
+}
+
+inline void shuffle( tests & specification, options option )
+{
+ std::shuffle( specification.begin(), specification.end(), std::mt19937( option.seed ) );
+}
+
+// workaround MinGW bug, http://stackoverflow.com/a/16132279:
+
+inline int stoi( text num )
+{
+ return static_cast<int>( std::strtol( num.c_str(), nullptr, 10 ) );
+}
+
+inline bool is_number( text arg )
+{
+ return std::all_of( arg.begin(), arg.end(), ::isdigit );
+}
+
+inline seed_t seed( text opt, text arg )
+{
+ if ( is_number( arg ) )
+ return static_cast<seed_t>( lest::stoi( arg ) );
+
+ if ( arg == "time" )
+ return static_cast<seed_t>( std::chrono::high_resolution_clock::now().time_since_epoch().count() );
+
+ throw std::runtime_error( "expecting 'time' or positive number with option '" + opt + "', got '" + arg + "' (try option --help)" );
+}
+
+inline int repeat( text opt, text arg )
+{
+ const int num = lest::stoi( arg );
+
+ if ( indefinite( num ) || num >= 0 )
+ return num;
+
+ throw std::runtime_error( "expecting '-1' or positive number with option '" + opt + "', got '" + arg + "' (try option --help)" );
+}
+
+inline auto split_option( text arg ) -> std::tuple<text, text>
+{
+ auto pos = arg.rfind( '=' );
+
+ return pos == text::npos
+ ? std::make_tuple( arg, "" )
+ : std::make_tuple( arg.substr( 0, pos ), arg.substr( pos + 1 ) );
+}
+
+inline auto split_arguments( texts args ) -> std::tuple<options, texts>
+{
+ options option; texts in;
+
+ bool in_options = true;
+
+ for ( auto & arg : args )
+ {
+ if ( in_options )
+ {
+ text opt, val;
+ std::tie( opt, val ) = split_option( arg );
+
+ if ( opt[0] != '-' ) { in_options = false; }
+ else if ( opt == "--" ) { in_options = false; continue; }
+ else if ( opt == "-h" || "--help" == opt ) { option.help = true; continue; }
+ else if ( opt == "-a" || "--abort" == opt ) { option.abort = true; continue; }
+ else if ( opt == "-c" || "--count" == opt ) { option.count = true; continue; }
+ else if ( opt == "-g" || "--list-tags" == opt ) { option.tags = true; continue; }
+ else if ( opt == "-l" || "--list-tests" == opt ) { option.list = true; continue; }
+ else if ( opt == "-t" || "--time" == opt ) { option.time = true; continue; }
+ else if ( opt == "-p" || "--pass" == opt ) { option.pass = true; continue; }
+ else if ( "--version" == opt ) { option.version = true; continue; }
+ else if ( opt == "--order" && "declared" == val ) { /* by definition */ ; continue; }
+ else if ( opt == "--order" && "lexical" == val ) { option.lexical = true; continue; }
+ else if ( opt == "--order" && "random" == val ) { option.random = true; continue; }
+ else if ( opt == "--random-seed" ) { option.seed = seed ( "--random-seed", val ); continue; }
+ else if ( opt == "--repeat" ) { option.repeat = repeat( "--repeat" , val ); continue; }
+ else throw std::runtime_error( "unrecognised option '" + arg + "' (try option --help)" );
+ }
+ in.push_back( arg );
+ }
+ return std::make_tuple( option, in );
+}
+
+inline int usage( std::ostream & os )
+{
+ os <<
+ "\nUsage: test [options] [test-spec ...]\n"
+ "\n"
+ "Options:\n"
+ " -h, --help this help message\n"
+ " -a, --abort abort at first failure\n"
+ " -c, --count count selected tests\n"
+ " -g, --list-tags list tags of selected tests\n"
+ " -l, --list-tests list selected tests\n"
+ " -p, --pass also report passing tests\n"
+ " -t, --time list duration of selected tests\n"
+ " --order=declared use source code test order (default)\n"
+ " --order=lexical use lexical sort test order\n"
+ " --order=random use random test order\n"
+ " --random-seed=n use n for random generator seed\n"
+ " --random-seed=time use time for random generator seed\n"
+ " --repeat=n repeat selected tests n times (-1: indefinite)\n"
+ " --version report lest version and compiler used\n"
+ " -- end options\n"
+ "\n"
+ "Test specification:\n"
+ " \"@\", \"*\" all tests, unless excluded\n"
+ " empty all tests, unless tagged [hide] or [.optional-name]\n"
+#if lest_FEATURE_REGEX_SEARCH
+ " \"re\" select tests that match regular expression\n"
+ " \"!re\" omit tests that match regular expression\n"
+#else
+ " \"text\" select tests that contain text (case insensitive)\n"
+ " \"!text\" omit tests that contain text (case insensitive)\n"
+#endif
+ ;
+ return 0;
+}
+
+inline text compiler()
+{
+ std::ostringstream os;
+#if defined (__clang__ )
+ os << "clang " << __clang_version__;
+#elif defined (__GNUC__ )
+ os << "gcc " << __GNUC__ << "." << __GNUC_MINOR__ << "." << __GNUC_PATCHLEVEL__;
+#elif defined ( _MSC_VER )
+ os << "MSVC " << (_MSC_VER / 100 - 5 - (_MSC_VER < 1900)) << " (" << _MSC_VER << ")";
+#else
+ os << "[compiler]";
+#endif
+ return os.str();
+}
+
+inline int version( std::ostream & os )
+{
+ os << "lest version " << lest_VERSION << "\n"
+ << "Compiled with " << compiler() << " on " << __DATE__ << " at " << __TIME__ << ".\n"
+ << "For more information, see https://github.com/martinmoene/lest.\n";
+ return 0;
+}
+
+inline int run( tests specification, texts arguments, std::ostream & os = std::cout )
+{
+ try
+ {
+ options option; texts in;
+ std::tie( option, in ) = split_arguments( arguments );
+
+ if ( option.lexical ) { sort( specification ); }
+ if ( option.random ) { shuffle( specification, option ); }
+
+ if ( option.help ) { return usage ( os ); }
+ if ( option.version ) { return version ( os ); }
+ if ( option.count ) { return for_test( specification, in, count( os ) ); }
+ if ( option.list ) { return for_test( specification, in, print( os ) ); }
+ if ( option.tags ) { return for_test( specification, in, ptags( os ) ); }
+ if ( option.time ) { return for_test( specification, in, times( os, option ) ); }
+
+ return for_test( specification, in, confirm( os, option ), option.repeat );
+ }
+ catch ( std::exception const & e )
+ {
+ os << "Error: " << e.what() << "\n";
+ return 1;
+ }
+}
+
+inline int run( tests specification, int argc, char * argv[], std::ostream & os = std::cout )
+{
+ return run( specification, texts( argv + 1, argv + argc ), os );
+}
+
+template <std::size_t N>
+int run( test const (&specification)[N], texts arguments, std::ostream & os = std::cout )
+{
+ return run( tests( specification, specification + N ), arguments, os );
+}
+
+template <std::size_t N>
+int run( test const (&specification)[N], std::ostream & os = std::cout )
+{
+ return run( tests( specification, specification + N ), {}, os );
+}
+
+template <std::size_t N>
+int run( test const (&specification)[N], int argc, char * argv[], std::ostream & os = std::cout )
+{
+ return run( tests( specification, specification + N ), texts( argv + 1, argv + argc ), os );
+}
+
+} // namespace lest
+
+#endif // LEST_LEST_HPP_INCLUDED
diff --git a/test/cpp/api/misc_t.cpp b/test/cpp/api/misc_t.cpp
new file mode 100644
index 0000000..686b8ed
--- /dev/null
+++ b/test/cpp/api/misc_t.cpp
@@ -0,0 +1,35 @@
+#include "test.h"
+
+CASE("misc/no_grad/1") {
+ no_grad_guard guard;
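+ // With gradient tracking disabled, backward should leave the weight grad undefined.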
+ auto model = Linear(5, 2).make();
+ auto x = Var(at::CPU(at::kFloat).randn({10, 5}), true);
+ auto y = model->forward({x})[0];
+ Variable s = y.sum();
+
+ backward(s);
+ EXPECT(!model->parameters()["weight"].grad().defined());
+};
+
+CASE("misc/random/seed_cpu") {
+ int size = 100;
+ setSeed(7);
+ auto x1 = Var(at::CPU(at::kFloat).randn({size}));
+ setSeed(7);
+ auto x2 = Var(at::CPU(at::kFloat).randn({size}));
+
+ auto l_inf = (x1.data() - x2.data()).abs().max().toCFloat();
+ EXPECT(l_inf < 1e-10);
+};
+
+CASE("misc/random/seed_cuda") {
+ CUDA_GUARD;
+ int size = 100;
+ setSeed(7);
+ auto x1 = Var(at::CUDA(at::kFloat).randn({size}));
+ setSeed(7);
+ auto x2 = Var(at::CUDA(at::kFloat).randn({size}));
+
+ auto l_inf = (x1.data() - x2.data()).abs().max().toCFloat();
+ EXPECT(l_inf < 1e-10);
+};
diff --git a/test/cpp/api/optim_t.cpp b/test/cpp/api/optim_t.cpp
new file mode 100644
index 0000000..8042f68
--- /dev/null
+++ b/test/cpp/api/optim_t.cpp
@@ -0,0 +1,99 @@
+#include "test.h"
+
+bool test_optimizer_xor(Optimizer optim, std::shared_ptr<ContainerList> model) {
+ float running_loss = 1;
+ int epoch = 0;
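+ // Train on random XOR minibatches until the running loss drops below 0.1;
+ // give up (and fail) after 3000 epochs.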
+ while (running_loss > 0.1) {
+ auto bs = 4U;
+ auto inp = at::CPU(at::kFloat).tensor({bs, 2});
+ auto lab = at::CPU(at::kFloat).tensor({bs});
+ for (auto i = 0U; i < bs; i++) {
+ auto a = std::rand() % 2;
+ auto b = std::rand() % 2;
+ auto c = a ^ b;
+ inp[i][0] = a;
+ inp[i][1] = b;
+ lab[i] = c;
+ }
+ // forward
+ auto x = Var(inp);
+ auto y = Var(lab, false);
+ for (auto layer : *model) x = layer->forward({x})[0].sigmoid_();
+ Variable loss = at::binary_cross_entropy(x, y);
+
+ optim->zero_grad();
+ backward(loss);
+ optim->step();
+
+ running_loss = running_loss * 0.99 + loss.data().sum().toCFloat() * 0.01;
+ if (epoch > 3000) {
+ return false;
+ }
+ epoch++;
+ }
+ return true;
+}
+
+CASE("optim/sgd") {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = SGD(model, 1e-1).momentum(0.9).nesterov().weight_decay(1e-6).make();
+ EXPECT(test_optimizer_xor(optim, model));
+}
+
+CASE("optim/adagrad") {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = Adagrad(model, 1.0).weight_decay(1e-6).lr_decay(1e-3).make();
+ EXPECT(test_optimizer_xor(optim, model));
+}
+
+CASE("optim/rmsprop") {
+ {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = RMSprop(model, 1e-1).momentum(0.9).weight_decay(1e-6).make();
+ EXPECT(test_optimizer_xor(optim, model));
+ }
+
+ {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = RMSprop(model, 1e-1).centered().make();
+ EXPECT(test_optimizer_xor(optim, model));
+ }
+}
+
+CASE("optim/adam") {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = Adam(model, 1.0).weight_decay(1e-6).make();
+ EXPECT(test_optimizer_xor(optim, model));
+}
+
+CASE("optim/amsgrad") {
+ auto model = ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+
+ auto optim = Adam(model, 0.1).weight_decay(1e-6).amsgrad().make();
+ EXPECT(test_optimizer_xor(optim, model));
+}
diff --git a/test/cpp/api/rnn_t.cpp b/test/cpp/api/rnn_t.cpp
new file mode 100644
index 0000000..3ed4ed3
--- /dev/null
+++ b/test/cpp/api/rnn_t.cpp
@@ -0,0 +1,168 @@
+#include "test.h"
+
+template <typename R, typename Func>
+bool test_RNN_xor(Func&& model_maker, bool cuda = false) {
+ auto nhid = 32;
+ auto model = SimpleContainer().make();
+ auto l1 = model->add(Linear(1, nhid).make(), "l1");
+ auto rnn = model->add(model_maker(nhid), "rnn");
+ auto lo = model->add(Linear(nhid, 1).make(), "lo");
+
+ auto optim = Adam(model, 1e-2).make();
+
+ auto forward_op = [&](Variable x) {
+ auto T = x.size(0);
+ auto B = x.size(1);
+ x = x.view({T * B, 1});
+ x = l1->forward({x})[0].view({T, B, nhid}).tanh_();
+ x = rnn->forward({x})[0][T-1];
+ x = lo->forward({x})[0];
+ return x;
+ };
+
+ if (cuda) {
+ model->cuda();
+ }
+
+ float running_loss = 1;
+ int epoch = 0;
+ auto max_epoch = 1500;
+ while (running_loss > 1e-2) {
+ auto bs = 16U;
+ auto nlen = 5U;
+ auto inp = at::CPU(at::kFloat).rand({nlen, bs, 1}).round().toType(at::kFloat);
+ auto lab = inp.sum(0);
+
+ if (cuda) {
+ inp = inp.toBackend(at::kCUDA);
+ lab = lab.toBackend(at::kCUDA);
+ }
+
+ auto x = Var(inp);
+ auto y = Var(lab, false);
+ x = forward_op(x);
+ Variable loss = at::mse_loss(x, y);
+
+ optim->zero_grad();
+ backward(loss);
+ optim->step();
+
+ running_loss = running_loss * 0.99 + loss.toCFloat() * 0.01;
+ if (epoch > max_epoch) {
+ return false;
+ }
+ epoch++;
+ }
+ return true;
+}
+
+CASE("RNN/LSTM/sizes") {
+ auto model = LSTM(128, 64).nlayers(2).dropout(0.2).make();
+ Variable x = Var(at::CPU(at::kFloat).randn({10, 16, 128}));
+ auto tup = model->forward({x});
+ auto y = x.mean();
+
+ auto out = tup[0];
+ auto hids = tup[1];
+
+ backward(y);
+ EXPECT(out.ndimension() == 3);
+ EXPECT(out.size(0) == 10);
+ EXPECT(out.size(1) == 16);
+ EXPECT(out.size(2) == 64);
+
+ EXPECT(hids.ndimension() == 4);
+ EXPECT(hids.size(0) == 2); // 2 layers
+ EXPECT(hids.size(1) == 2); // c and h
+ EXPECT(hids.size(2) == 16); // Batch size of 16
+ EXPECT(hids.size(3) == 64); // 64 hidden dims
+
+ // Something is in the hiddens
+ EXPECT(hids.norm().toCFloat() > 0);
+
+ Variable diff = model->forward({x, hids})[1] - hids;
+
+ // Hiddens changed
+ EXPECT(diff.data().abs().sum().toCFloat() > 1e-3);
+};
+
+CASE("RNN/LSTM/outputs") {
+ // Make sure the outputs match pytorch outputs
+ auto model = LSTM(2, 2).make();
+ for (auto& v : model->parameters()) {
+    size_t size = v.second.numel();
+    auto p = static_cast<float*>(v.second.data().storage()->data());
+    for (size_t i = 0; i < size; i++) {
+      p[i] = i / static_cast<float>(size);
+    }
+ }
+
+ Variable x = Var(at::CPU(at::kFloat).tensor({3, 4, 2}));
+  size_t size = x.data().numel();
+  auto p = static_cast<float*>(x.data().storage()->data());
+  for (size_t i = 0; i < size; i++) {
+    p[i] = (size - i) / static_cast<float>(size);
+  }
+
+ auto out = model->forward({x});
+ EXPECT(out[0].ndimension() == 3);
+ EXPECT(out[0].size(0) == 3);
+ EXPECT(out[0].size(1) == 4);
+ EXPECT(out[0].size(2) == 2);
+
+ auto flat = out[0].data().view(3*4*2);
+ float c_out[] = {0.4391, 0.5402, 0.4330, 0.5324, 0.4261, 0.5239, 0.4183,
+ 0.5147, 0.6822, 0.8064, 0.6726, 0.7968, 0.6620, 0.7860, 0.6501, 0.7741,
+ 0.7889, 0.9003, 0.7769, 0.8905, 0.7635, 0.8794, 0.7484, 0.8666};
+ for (size_t i = 0; i < 3*4*2; i++) {
+ EXPECT(std::abs(flat[i].toCFloat() - c_out[i]) < 1e-3);
+ }
+
+  EXPECT(out[1].ndimension() == 4); // layers x (hx, cx) x B x hidden
+ EXPECT(out[1].size(0) == 1);
+ EXPECT(out[1].size(1) == 2);
+ EXPECT(out[1].size(2) == 4);
+ EXPECT(out[1].size(3) == 2);
+ flat = out[1].data().view(16);
+ float h_out[] = {0.7889, 0.9003, 0.7769, 0.8905, 0.7635, 0.8794, 0.7484,
+ 0.8666, 1.1647, 1.6106, 1.1425, 1.5726, 1.1187, 1.5329, 1.0931, 1.4911};
+ for (size_t i = 0; i < 16; i++) {
+ EXPECT(std::abs(flat[i].toCFloat() - h_out[i]) < 1e-3);
+ }
+};
+
+CASE("integration/RNN/LSTM") {
+ EXPECT(test_RNN_xor<LSTM>([](int s) { return LSTM(s, s).nlayers(2).make(); }));
+};
+
+CASE("integration/RNN/GRU") {
+ EXPECT(test_RNN_xor<GRU>([](int s) { return GRU(s, s).nlayers(2).make(); }));
+};
+
+CASE("integration/RNN/RNN/Relu") {
+ EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Relu).nlayers(2).make(); }));
+};
+
+CASE("integration/RNN/RNN/Tanh") {
+ EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Tanh).nlayers(2).make(); }));
+};
+
+CASE("integration/RNN/cuda/LSTM") {
+ CUDA_GUARD;
+ EXPECT(test_RNN_xor<LSTM>([](int s) { return LSTM(s, s).nlayers(2).make(); }, true));
+};
+
+CASE("integration/RNN/cuda/GRU") {
+ CUDA_GUARD;
+ EXPECT(test_RNN_xor<GRU>([](int s) { return GRU(s, s).nlayers(2).make(); }, true));
+};
+
+CASE("integration/RNN/cuda/RNN/Relu") {
+ CUDA_GUARD;
+ EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Relu).nlayers(2).make(); }, true));
+};
+
+CASE("integration/RNN/cuda/RNN/Tanh") {
+ CUDA_GUARD;
+ EXPECT(test_RNN_xor<RNN>([](int s) { return RNN(s, s, RNN::Mode::Tanh).nlayers(2).make(); }, true));
+};
diff --git a/test/cpp/api/serialization_t.cpp b/test/cpp/api/serialization_t.cpp
new file mode 100644
index 0000000..54efb61
--- /dev/null
+++ b/test/cpp/api/serialization_t.cpp
@@ -0,0 +1,261 @@
+#include "test.h"
+
+#include "cereal/archives/portable_binary.hpp"
+
+CASE("serialization/undefined") {
+ auto x = at::Tensor();
+
+ EXPECT(!x.defined());
+
+ auto y = at::CPU(at::kFloat).randn({5});
+
+ std::stringstream ss;
+ save(ss, &x);
+ load(ss, &y);
+
+ EXPECT(!y.defined());
+}
+
+CASE("serialization/cputypes") {
+ for (int i = 0; i < static_cast<int>(at::ScalarType::NumOptions); i++) {
+ if (i == static_cast<int>(at::ScalarType::Half)) {
+      // XXX can't serialize half tensors at the moment, since contiguous()
+      // is not implemented for this type.
+ continue;
+ } else if (i == static_cast<int>(at::ScalarType::Undefined)) {
+ // We can't construct a tensor for this type. This is tested in
+ // serialization/undefined anyway.
+ continue;
+ }
+
+ auto x =
+ at::getType(at::kCPU, static_cast<at::ScalarType>(i)).ones({5, 5});
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ save(ss, &x);
+ load(ss, &y);
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ if (at::isIntegralType(static_cast<at::ScalarType>(i))) {
+ EXPECT(x.equal(y));
+ } else {
+ EXPECT(x.allclose(y));
+ }
+ }
+}
+
+CASE("serialization/binary") {
+ auto x = at::CPU(at::kFloat).randn({5, 5});
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ {
+ cereal::BinaryOutputArchive archive(ss);
+ archive(x);
+ }
+ {
+ cereal::BinaryInputArchive archive(ss);
+ archive(y);
+ }
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ EXPECT(x.allclose(y));
+}
+
+CASE("serialization/portable_binary") {
+ auto x = at::CPU(at::kFloat).randn({5, 5});
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ {
+ cereal::PortableBinaryOutputArchive archive(ss);
+ archive(x);
+ }
+ {
+ cereal::PortableBinaryInputArchive archive(ss);
+ archive(y);
+ }
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ EXPECT(x.allclose(y));
+}
+
+CASE("serialization/resized") {
+ auto x = at::CPU(at::kFloat).randn({11, 5});
+ x.resize_({5, 5});
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ {
+ cereal::BinaryOutputArchive archive(ss);
+ archive(x);
+ }
+ {
+ cereal::BinaryInputArchive archive(ss);
+ archive(y);
+ }
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ EXPECT(x.allclose(y));
+}
+
+CASE("serialization/sliced") {
+ auto x = at::CPU(at::kFloat).randn({11, 5});
+ x = x.slice(0, 1, 3);
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ {
+ cereal::BinaryOutputArchive archive(ss);
+ archive(x);
+ }
+ {
+ cereal::BinaryInputArchive archive(ss);
+ archive(y);
+ }
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ EXPECT(x.allclose(y));
+}
+
+CASE("serialization/noncontig") {
+ auto x = at::CPU(at::kFloat).randn({11, 5});
+ x = x.slice(1, 1, 4);
+ auto y = at::Tensor();
+
+ std::stringstream ss;
+ {
+ cereal::BinaryOutputArchive archive(ss);
+ archive(x);
+ }
+ {
+ cereal::BinaryInputArchive archive(ss);
+ archive(y);
+ }
+
+ EXPECT(y.defined());
+ EXPECT(x.sizes().vec() == y.sizes().vec());
+ EXPECT(x.allclose(y));
+}
+
+CASE("serialization/xor") {
+  // We had better be able to save and load an XOR model!
+ auto makeModel = []() {
+ return ContainerList()
+ .append(Linear(2, 8).make())
+ .append(Linear(8, 1).make())
+ .make();
+ };
+ auto getLoss = [](std::shared_ptr<ContainerList> model, uint32_t bs) {
+ auto inp = at::CPU(at::kFloat).tensor({bs, 2});
+ auto lab = at::CPU(at::kFloat).tensor({bs});
+ for (auto i = 0U; i < bs; i++) {
+ auto a = std::rand() % 2;
+ auto b = std::rand() % 2;
+ auto c = a ^ b;
+ inp[i][0] = a;
+ inp[i][1] = b;
+ lab[i] = c;
+ }
+
+ // forward
+ auto x = Var(inp);
+ auto y = Var(lab, false);
+ for (auto layer : *model) x = layer->forward({x})[0].sigmoid_();
+ return at::binary_cross_entropy(x, y);
+ };
+
+ auto model = makeModel();
+ auto model2 = makeModel();
+ auto model3 = makeModel();
+ auto optim = SGD(model, 1e-1).momentum(0.9).nesterov().weight_decay(1e-6).make();
+
+ float running_loss = 1;
+ int epoch = 0;
+ while (running_loss > 0.1) {
+ Variable loss = getLoss(model, 4);
+ optim->zero_grad();
+ backward(loss);
+ optim->step();
+
+ running_loss = running_loss * 0.99 + loss.data().sum().toCFloat() * 0.01;
+ EXPECT(epoch < 3000);
+ epoch++;
+ }
+
+ std::stringstream ss;
+ save(ss, model);
+ load(ss, model2);
+
+ auto loss = getLoss(model2, 100);
+ EXPECT(loss.toCFloat() < 0.1);
+
+ CUDA_GUARD;
+ model2->cuda();
+ ss.clear();
+ save(ss, model2);
+ load(ss, model3);
+
+ loss = getLoss(model3, 100);
+ EXPECT(loss.toCFloat() < 0.1);
+}
+
+CASE("serialization/optim") {
+ auto model1 = Linear(5, 2).make();
+ auto model2 = Linear(5, 2).make();
+ auto model3 = Linear(5, 2).make();
+
+ // Models 1, 2, 3 will have the same params
+ std::stringstream ss;
+ save(ss, model1);
+ load(ss, model2);
+ ss.seekg(0, std::ios::beg);
+ load(ss, model3);
+
+ // Make some optimizers with momentum (and thus state)
+ auto optim1 = SGD(model1, 1e-1).momentum(0.9).make();
+ auto optim2 = SGD(model2, 1e-1).momentum(0.9).make();
+ auto optim2_2 = SGD(model2, 1e-1).momentum(0.9).make();
+ auto optim3 = SGD(model3, 1e-1).momentum(0.9).make();
+ auto optim3_2 = SGD(model3, 1e-1).momentum(0.9).make();
+
+ auto x = Var(at::CPU(at::kFloat).ones({10, 5}), true);
+
+ auto step = [&](Optimizer optim, Container model) {
+ optim->zero_grad();
+ auto y = model->forward({x})[0].sum();
+ backward(y);
+ optim->step();
+ };
+
+ // Do 2 steps of model1
+ step(optim1, model1);
+ step(optim1, model1);
+
+ // Do 2 steps of model 2 without saving the optimizer
+ step(optim2, model2);
+ step(optim2_2, model2);
+
+ // Do 2 steps of model 3 while saving the optimizer
+ step(optim3, model3);
+ ss.clear();
+ save(ss, optim3);
+ load(ss, optim3_2);
+ step(optim3_2, model3);
+
+ auto param1 = model1->parameters();
+ auto param2 = model2->parameters();
+ auto param3 = model3->parameters();
+ for (auto& p : param1) {
+ auto name = p.first;
+ // Model 1 and 3 should be the same
+ EXPECT(param1[name].norm().toCFloat() == param3[name].norm().toCFloat());
+ EXPECT(param1[name].norm().toCFloat() != param2[name].norm().toCFloat());
+ }
+}
diff --git a/test/cpp/api/test.cpp b/test/cpp/api/test.cpp
new file mode 100644
index 0000000..b1ae62e
--- /dev/null
+++ b/test/cpp/api/test.cpp
@@ -0,0 +1,10 @@
+#include "test.h"
+
+lest::tests & specification() {
+ static lest::tests tests;
+ return tests;
+}
+
+int main( int argc, char * argv[] ) {
+ return lest::run( specification(), argc, argv);
+}
diff --git a/test/cpp/api/test.h b/test/cpp/api/test.h
new file mode 100644
index 0000000..7b19d23
--- /dev/null
+++ b/test/cpp/api/test.h
@@ -0,0 +1,14 @@
+#pragma once
+
+#include "lest.hpp"
+#include <torch/autograd.h>
+
+using namespace autograd;
+
+#define CASE( name ) lest_CASE( specification(), name )
+
+#define CUDA_GUARD \
+  if (!hasCuda()) { \
+    std::cerr << "No CUDA, skipping test" << std::endl; \
+    return; \
+  }
+
+extern lest::tests & specification();
diff --git a/third_party/cereal b/third_party/cereal
new file mode 160000
index 0000000..51cbda5
--- /dev/null
+++ b/third_party/cereal
@@ -0,0 +1 @@
+Subproject commit 51cbda5f30e56c801c07fe3d3aba5d7fb9e6cca4
diff --git a/tools/cpp_build/libtorch/CMakeLists.txt b/tools/cpp_build/libtorch/CMakeLists.txt
index ca7e264..f8cfe82 100644
--- a/tools/cpp_build/libtorch/CMakeLists.txt
+++ b/tools/cpp_build/libtorch/CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
cmake_policy(VERSION 3.0)
-set(CMAKE_CXX_STANDARD 14)
+set(CMAKE_CXX_STANDARD 11)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
@@ -155,10 +155,19 @@
${TORCH_SRC_DIR}/csrc/jit/type.cpp
${TORCH_SRC_DIR}/csrc/jit/interpreter_autograd_function.cpp
${TORCH_SRC_DIR}/csrc/Exceptions.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/detail.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/containers.cpp
+ ${TORCH_SRC_DIR}/csrc/api/src/optimizers.cpp
)
add_library(torch SHARED ${TORCH_SRCS})
+target_compile_options(torch PRIVATE -Wall -Wextra)
+
+if ($ENV{WERROR})
+ target_compile_options(torch PRIVATE -Werror)
+endif()
+
target_link_libraries(torch
${TORCH_CUDA_LIBRARIES}
${ATEN_LIBRARY}
@@ -169,6 +178,10 @@
"${ATEN_INCLUDE_DIR}/TH"
"${ATEN_BUILD_INCLUDE_DIR}"
"${ATEN_BUILD_PATH}/src/TH"
+ "${TORCH_SRC_DIR}/csrc/api/"
+ "${TORCH_SRC_DIR}/csrc/api/include"
+ "${TORCH_SRC_DIR}/../third_party/cereal/include" # For cereal/
+ "${TORCH_SRC_DIR}/../"
"${CMAKE_CURRENT_SOURCE_DIR}"
"${CUDA_INCLUDE_DIRS}")
@@ -193,9 +206,9 @@
LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}"
ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}")
-set(TORCH_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp)
+# JIT Tests. TODO: Put into test/cpp/jit folder
-add_executable(test_jit ${TORCH_TEST_SRCS})
+add_executable(test_jit ${TORCH_SRC_DIR}/csrc/jit/test_jit.cpp)
target_link_libraries(test_jit torch)
@@ -204,7 +217,19 @@
"${TORCH_SRC_DIR}/../third_party/catch/single_include"
"${COMMON_INCLUDES}")
-install(TARGETS test_jit
- RUNTIME DESTINATION "${TORCH_INSTALL_BIN_DIR}"
- LIBRARY DESTINATION "${TORCH_INSTALL_LIB_DIR}"
- ARCHIVE DESTINATION "${TORCH_INSTALL_LIB_DIR}")
+# API Tests
+
+set(TORCH_API_TEST_DIR "${TORCH_SRC_DIR}/../test/cpp/api")
+
+add_executable(test_api
+ ${TORCH_API_TEST_DIR}/test.cpp
+ ${TORCH_API_TEST_DIR}/container_t.cpp
+ ${TORCH_API_TEST_DIR}/misc_t.cpp
+ ${TORCH_API_TEST_DIR}/rnn_t.cpp
+ ${TORCH_API_TEST_DIR}/integration_t.cpp
+ ${TORCH_API_TEST_DIR}/optim_t.cpp
+ ${TORCH_API_TEST_DIR}/serialization_t.cpp
+)
+
+target_compile_options(test_api PRIVATE -Dlest_FEATURE_AUTO_REGISTER=1)
+target_link_libraries(test_api torch)
diff --git a/tools/download_mnist.py b/tools/download_mnist.py
new file mode 100644
index 0000000..2e65d60
--- /dev/null
+++ b/tools/download_mnist.py
@@ -0,0 +1,81 @@
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import gzip
+import os
+import sys
+
+try:
+ from urllib.error import URLError
+ from urllib.request import urlretrieve
+except ImportError:
+ from urllib2 import URLError
+ from urllib import urlretrieve
+
+RESOURCES = [
+ 'train-images-idx3-ubyte.gz',
+ 'train-labels-idx1-ubyte.gz',
+ 't10k-images-idx3-ubyte.gz',
+ 't10k-labels-idx1-ubyte.gz',
+]
+
+
+def report_download_progress(chunk_number, chunk_size, file_size):
+ if file_size != -1:
+ percent = min(1, (chunk_number * chunk_size) / file_size)
+ bar = '#' * int(64 * percent)
+ sys.stdout.write('\r0% |{:<64}| {}%'.format(bar, int(percent * 100)))
+
+
+def download(destination_path, url):
+ if os.path.exists(destination_path):
+ print('{} already exists, skipping ...'.format(destination_path))
+ else:
+ print('Downloading {} ...'.format(url))
+ try:
+ urlretrieve(
+ url, destination_path, reporthook=report_download_progress)
+ except URLError:
+ raise RuntimeError('Error downloading resource!')
+ finally:
+ # Just a newline.
+ print()
+
+
+def unzip(zipped_path):
+ unzipped_path = os.path.splitext(zipped_path)[0]
+ if os.path.exists(unzipped_path):
+ print('{} already exists, skipping ... '.format(unzipped_path))
+ return
+ with gzip.open(zipped_path, 'rb') as zipped_file:
+ with open(unzipped_path, 'wb') as unzipped_file:
+ unzipped_file.write(zipped_file.read())
+ print('Unzipped {} ...'.format(zipped_path))
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Download the MNIST dataset from the internet')
+ parser.add_argument(
+ '-d', '--destination', default='.', help='Destination directory')
+ options = parser.parse_args()
+
+ if not os.path.exists(options.destination):
+ os.makedirs(options.destination)
+
+ try:
+ for resource in RESOURCES:
+ path = os.path.join(options.destination, resource)
+ url = 'http://yann.lecun.com/exdb/mnist/{}'.format(resource)
+ download(path, url)
+ unzip(path)
+ except KeyboardInterrupt:
+    print('Interrupted')
+    return
+
+ print('Done')
+
+
+if __name__ == '__main__':
+ main()
diff --git a/torch/csrc/api/README.md b/torch/csrc/api/README.md
new file mode 100644
index 0000000..10ff141
--- /dev/null
+++ b/torch/csrc/api/README.md
@@ -0,0 +1,50 @@
+# AUTOGRADPP
+
+This is an experimental C++ frontend to PyTorch's C++ backend. Use at your own
+risk.
+
+How to build:
+```
+git submodule update --init --recursive
+
+cd pytorch
+# On Linux:
+python setup.py build
+# On macOS (may need to prefix with `MACOSX_DEPLOYMENT_TARGET=10.9 CC=clang CXX=clang++` when using anaconda)
+LDSHARED="cc -dynamiclib -undefined dynamic_lookup" python setup.py build
+
+cd ..; mkdir -p build; cd build
+cmake .. -DPYTHON_EXECUTABLE:FILEPATH=$(which python) # helpful if you use anaconda
+make -j
+```
+
+# Usage
+
+- Check out the [MNIST example](https://github.com/ebetica/autogradpp/blob/eee977ddd377c484af5fce09ae8676410bb6fcce/tests/integration_t.cpp#L320-L355),
+which tries to replicate PyTorch's MNIST model + training loop
+- The principled way to write a model is probably something like
+```
+AUTOGRAD_CONTAINER_CLASS(MyModel) {
+  // A 2D convolution, followed by global mean pooling, followed by a linear layer.
+ public:
+  void initialize_containers() override {
+    myConv_ = add(Conv2d(1, 50, 3).stride(2).make(), "conv");
+    myLinear_ = add(Linear(50, 1).make(), "linear");
+  }
+  variable_list forward(variable_list x) override {
+    auto v = myConv_->forward(x)[0];
+    v = v.mean(-1).mean(-1);
+    return myLinear_->forward({v});
+  }
+ private:
+  Container myLinear_;
+  Container myConv_;
+};
+```
+
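+Once a model is defined, training follows a make/forward/backward/step cycle.
+A minimal sketch (using the hypothetical `MyModel` above and a dummy
+sum-of-outputs "loss", just to show the moving parts):
+```
+auto model = MyModel().make();
+auto optim = SGD(model, 1e-2).momentum(0.9).make();
+
+auto x = Var(at::CPU(at::kFloat).randn({8, 1, 28, 28}));
+auto loss = model->forward({x})[0].sum();  // stand-in for a real loss
+
+optim->zero_grad();
+backward(loss);
+optim->step();
+```
+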
+Known limitations:
+- SGD, Adagrad, RMSprop, and Adam are the only optimizers implemented
+- Bidirectional, batch-first, and PackedSequence are not implemented for LSTMs
+- Sparse tensors might work but are very untested
+
+Otherwise, lots of other things work. There may be breaking API changes.
diff --git a/torch/csrc/api/include/torch/autograd.h b/torch/csrc/api/include/torch/autograd.h
new file mode 100644
index 0000000..9bd84e4
--- /dev/null
+++ b/torch/csrc/api/include/torch/autograd.h
@@ -0,0 +1,4 @@
+#pragma once
+#include "torch/containers.h"
+#include "torch/optimizers.h"
+#include "torch/serialization.h"
diff --git a/torch/csrc/api/include/torch/containers.h b/torch/csrc/api/include/torch/containers.h
new file mode 100644
index 0000000..8fa6a39
--- /dev/null
+++ b/torch/csrc/api/include/torch/containers.h
@@ -0,0 +1,446 @@
+#pragma once
+
+#include "detail.h"
+
+#include "torch/csrc/autograd/variable.h"
+
+#define AUTOGRAD_CONTAINER_CLASS(Type) \
+ class Type : public autograd::Container_CRTP<Type>
+
+namespace autograd {
+class ContainerImpl {
+ public:
+ // Only construct parameters in initialize_parameters, and
+ // containers in initialize_containers. Most of the time, the containers are
+ // the only thing you need to add.
+ // You are guaranteed that containers are added before parameters.
+ virtual void initialize_containers(){};
+ virtual void initialize_parameters(){};
+ virtual void reset_parameters(){};
+
+ virtual variable_list forward(variable_list) = 0;
+ virtual Container clone() const = 0;
+
+ std::map<std::string, Variable> parameters() const;
+ Variable& param(std::string const&);
+
+ virtual void cuda();
+ virtual void cpu();
+ void train();
+ void eval();
+
+ at::Type& DefaultTensor(at::ScalarType s);
+
+ std::unordered_map<std::string, Container> children_;
+ std::unordered_map<std::string, Variable> params_;
+ bool cuda_ = false;
+ bool train_ = true;
+
+ template <class Archive>
+ void save(Archive& ar) const {
+ auto params = parameters();
+ std::size_t size = params.size();
+ ar(size);
+ for (auto& p : params) {
+ ar(p.first, p.second);
+ }
+ }
+
+ template <class Archive>
+ void load(Archive& ar) {
+ auto params = parameters();
+ std::size_t size;
+ ar(size);
+ std::string name;
+ for (std::size_t i = 0; i < size; i++) {
+ ar(name);
+ ar(params[name]);
+ }
+ }
+
+ protected:
+ Container add(Container, std::string const&);
+ // Be careful when registering Tensors that are not variables
+ Variable& add(Variable, std::string const&);
+};
+
+template <class Derived>
+class Container_CRTP : public ContainerImpl {
+ public:
+ std::shared_ptr<Derived> make() const {
+ auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
+ ptr->initialize_containers();
+ ptr->initialize_parameters();
+ ptr->reset_parameters();
+ return ptr;
+ }
+
+ Container clone() const override {
+ auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
+ ptr->children_.clear();
+ ptr->params_.clear();
+ ptr->initialize_containers();
+ ptr->initialize_parameters();
+ auto newParams = ptr->parameters();
+ for (auto& param : parameters()) {
+ newParams[param.first].data().copy_(param.second.data());
+ }
+ if (cuda_) {
+ ptr->cuda();
+ } else {
+ ptr->cpu();
+ }
+ return ptr;
+ }
+};
+
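+// Usage sketch: concrete containers are instantiated through make(), which
+// runs the initialize_containers / initialize_parameters / reset_parameters
+// hooks in that order:
+//   auto linear = Linear(10, 5).make();  // std::shared_ptr<Linear>
+//   auto copy = linear->clone();         // new Container with copied params
+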
+template <class Derived>
+class ContainerListImpl : public Container_CRTP<Derived> {
+ // Lets you use a container like a vector without making a new class,
+ // just for simple implementations
+ public:
+ virtual variable_list forward(variable_list) override {
+ throw std::runtime_error(
+ "ContainerList has no forward, maybe you"
+ " wanted to subclass and override this function?");
+ }
+
+ Container add(Container m) {
+ return append(m).children_.back();
+ }
+
+ ContainerListImpl<Derived>& append(Container m) {
+ children_.push_back(m);
+ ContainerImpl::add(children_.back(), std::to_string(size() - 1));
+ return *this;
+ }
+
+ Container& operator[](int index) {
+ return children_[index];
+ }
+
+ int size() {
+ return children_.size();
+ }
+
+ std::vector<Container>::iterator begin() {
+ return children_.begin();
+ }
+
+ std::vector<Container>::iterator end() {
+ return children_.end();
+ }
+
+ std::vector<Container> children_;
+};
+
+class ContainerList : public ContainerListImpl<ContainerList> {};
+
+class Sequential : public ContainerListImpl<Sequential> {
+ // Mimics nn.Sequential from pytorch.
+ public:
+ variable_list forward(variable_list input) override {
+ for (auto& container : children_) {
+ input = container->forward(input);
+ }
+ return input;
+ }
+
+ Container add(Container m, std::string name = "") {
+ return append(m, name).children_.back();
+ }
+
+ Sequential& append(Container m, std::string name = "") {
+ if (name == "") {
+ name = std::to_string(size());
+ }
+ children_.push_back(m);
+ ContainerImpl::add(children_.back(), name);
+ return *this;
+ }
+};
+
+AUTOGRAD_CONTAINER_CLASS(SimpleContainer) {
+ // Lets you use a container without making a new class,
+ // for experimental implementations
+ public:
+ virtual variable_list forward(variable_list) override {
+ throw std::runtime_error(
+ "SimpleContainer has no forward, maybe you"
+ " wanted to subclass and override this function?");
+ }
+ using ContainerImpl::add;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Functional) {
+ // Lets you create a container from a function, designed for use in
+ // Sequential.
+ public:
+ Functional(std::function<variable_list(variable_list)> fun) : fun_(fun){};
+ Functional(std::function<Variable(Variable)> fun)
+ : fun_([fun](variable_list input) {
+ return variable_list({fun(input[0])});
+ }){};
+
+ variable_list forward(variable_list input) override {
+ return fun_(input);
+ };
+
+ std::function<variable_list(variable_list)> fun_;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Linear) {
+ public:
+ Linear(uint32_t nin, uint32_t nout) : nin(nin), nout(nout) {}
+
+ variable_list forward(variable_list) override;
+ void reset_parameters() override;
+ void initialize_parameters() override;
+ AUTOGRAD_KWARG(Linear, bool, no_bias, false, true);
+
+ Variable weight, bias;
+ uint32_t nin, nout;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Embedding) {
+ public:
+ Embedding(uint32_t num_embeddings, uint32_t embedding_dim)
+ : num_embeddings(num_embeddings), embedding_dim(embedding_dim) {}
+
+ variable_list forward(variable_list) override;
+ void reset_parameters() override;
+ void initialize_parameters() override;
+
+ Variable weight;
+ uint32_t num_embeddings, embedding_dim;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Conv) {
+ private:
+ Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan)
+ : Nd_(Nd),
+ in_channels_(in_chan),
+ out_channels_(out_chan),
+ stride_(makeTup(1, 1)),
+ padding_(makeTup(0)),
+ dilation_(makeTup(1, 1)),
+ dilated_(false),
+ output_padding_(makeTup(0)) {}
+
+ public:
+ Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan, int ks)
+ : Conv(Nd, in_chan, out_chan) {
+ ks_ = makeTup(ks, 1);
+ }
+
+ Conv(uint32_t Nd, uint32_t in_chan, uint32_t out_chan, IntVec ks)
+ : Conv(Nd, in_chan, out_chan) {
+ ks_ = makeTup(ks);
+ }
+
+ void reset_parameters() override;
+ variable_list forward(variable_list) override;
+ void initialize_parameters() override;
+
+ template <typename T>
+ Conv& stride(T s) {
+ stride_ = makeTup(s, 1);
+ return *this;
+ }
+ template <typename T>
+ Conv& padding(T s) {
+ padding_ = makeTup(s);
+ return *this;
+ }
+ template <typename T>
+ Conv& dilation(T s) {
+ dilation_ = makeTup(s, 1);
+ return *this;
+ }
+ template <typename T>
+ Conv& output_padding(T s) {
+ output_padding_ = makeTup(s);
+ return *this;
+ }
+
+ AUTOGRAD_KWARG(Conv, bool, transposed, false, true)
+ AUTOGRAD_KWARG(Conv, bool, no_bias, false, true)
+ AUTOGRAD_KWARG(Conv, int, groups, 1, 1)
+
+ Variable weight, bias;
+ uint32_t Nd_;
+ uint32_t in_channels_;
+ uint32_t out_channels_;
+ IntVec ks_;
+ IntVec stride_;
+ IntVec padding_;
+ IntVec dilation_;
+ bool dilated_;
+ IntVec output_padding_;
+
+ protected:
+ IntVec makeTup(int x, int def = 0) {
+ IntVec ret;
+ if (Nd_ == 1) {
+ ret.push_back(x);
+ ret.push_back(def);
+ } else {
+ for (auto i = 0U; i < Nd_; i++)
+ ret.push_back(x);
+ }
+ return ret;
+ }
+ IntVec makeTup(IntVec x) {
+ return x;
+ }
+};
+
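+// Usage sketch: kernel size, stride, padding, and dilation accept either a
+// single int or an explicit IntVec, and compose fluently:
+//   auto conv = Conv2d(3, 16, 3).stride(2).padding(1).make();
+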
+class Conv1d : public Conv {
+ public:
+ Conv1d(uint32_t i, uint32_t o, int ks) : Conv(1, i, o, ks) {}
+ Conv1d(uint32_t i, uint32_t o, IntVec ks) : Conv(1, i, o, ks) {}
+};
+
+class Conv2d : public Conv {
+ public:
+ Conv2d(uint32_t i, uint32_t o, int ks) : Conv(2, i, o, ks) {}
+ Conv2d(uint32_t i, uint32_t o, IntVec ks) : Conv(2, i, o, ks) {}
+};
+
+class Conv3d : public Conv {
+ public:
+ Conv3d(uint32_t i, uint32_t o, int ks) : Conv(3, i, o, ks) {}
+ Conv3d(uint32_t i, uint32_t o, IntVec ks) : Conv(3, i, o, ks) {}
+};
+
+AUTOGRAD_CONTAINER_CLASS(BatchNorm) {
+ public:
+ BatchNorm(uint32_t num_features) : num_features_(num_features) {}
+
+ AUTOGRAD_KWARG(BatchNorm, double, eps, 1e-5, 1e-5)
+ AUTOGRAD_KWARG(BatchNorm, double, momentum, 0.1, 0.1)
+ AUTOGRAD_KWARG(BatchNorm, bool, affine, true, true)
+ AUTOGRAD_KWARG(BatchNorm, bool, stateful, false, true)
+
+ void reset_parameters() override;
+ variable_list forward(variable_list) override;
+ void initialize_parameters() override;
+
+ Variable weight;
+ Variable bias;
+ Variable running_mean;
+ Variable running_var;
+
+ protected:
+ uint32_t num_features_;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Dropout) {
+ public:
+ Dropout(double p = 0.5) : p_(p) {
+ assert(p < 1 && p >= 0);
+ }
+ variable_list forward(variable_list) override;
+
+ protected:
+ double p_;
+};
+
+AUTOGRAD_CONTAINER_CLASS(Dropout2d) {
+ public:
+ Dropout2d(double p = 0.5) : p_(p) {
+ assert(p < 1 && p >= 0);
+ }
+ variable_list forward(variable_list) override;
+
+ protected:
+ double p_;
+};
+
+template <typename Derived>
+class RNNBase : public Container_CRTP<Derived> {
+ public:
+ // These must line up with the CUDNN mode codes
+ enum RNNMode : int64_t { RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3 };
+ RNNBase(uint32_t input_size, uint32_t hidden_size)
+ : input_size_(input_size), hidden_size_(hidden_size) {}
+
+ AUTOGRAD_KWARG(RNNBase, RNNMode, mode, RNNMode::LSTM, RNNMode::LSTM)
+ AUTOGRAD_KWARG(RNNBase, uint32_t, nlayers, 1, 1);
+ AUTOGRAD_KWARG(RNNBase, bool, no_bias, false, true)
+ AUTOGRAD_KWARG(RNNBase, float, dropout, 0, 0)
+
+ bool flatten_parameters(); // Flatten for cudnn
+
+ variable_list forward(variable_list) override;
+ void initialize_containers() override;
+ void reset_parameters() override;
+
+ void cpu() override;
+ void cuda() override;
+
+ std::vector<Container> i2h;
+ std::vector<Container> h2h;
+
+ protected:
+ uint32_t input_size_;
+ uint32_t hidden_size_;
+ uint32_t gate_size_;
+  // This is copied from PyTorch, to determine whether weights are flat for
+  // the fast CUDNN route. Otherwise, we have to use non-flattened weights,
+  // which are much slower.
+ // https://github.com/pytorch/pytorch/blob/1848cad10802db9fa0aa066d9de195958120d863/torch/nn/modules/rnn.py#L159-L165
+ // TODO Actually since we are in C++ we can probably just actually check if
+ // the parameters are flat, instead of relying on data pointers and stuff.
+ std::vector<void*> data_ptrs_;
+ Variable flat_weight_;
+ Container dropout_module;
+
+ variable_list CUDNN_forward(variable_list);
+ variable_list autograd_forward(variable_list);
+
+ variable_list cell_forward(variable_list, int);
+ variable_list LSTM_cell_forward(variable_list, int);
+ variable_list GRU_cell_forward(variable_list, int);
+ variable_list RNN_RELU_cell_forward(variable_list, int);
+ variable_list RNN_TANH_cell_forward(variable_list, int);
+};
+
+// We must instantiate these templates so we can put implementations in the .cpp
+class LSTM;
+template class RNNBase<LSTM>;
+class LSTM : public RNNBase<LSTM> {
+ public:
+ LSTM(uint32_t inp_size, uint32_t hid_size) : RNNBase(inp_size, hid_size) {
+ mode_ = RNNBase::RNNMode::LSTM;
+ }
+};
+
+class GRU;
+template class RNNBase<GRU>;
+class GRU : public RNNBase<GRU> {
+ public:
+ GRU(uint32_t inp_size, uint32_t hid_size) : RNNBase(inp_size, hid_size) {
+ mode_ = RNNBase::RNNMode::GRU;
+ }
+};
+
+class RNN;
+template class RNNBase<RNN>;
+class RNN : public RNNBase<RNN> {
+ public:
+ enum Mode { Tanh, Relu };
+ RNN(uint32_t inp_size, uint32_t hid_size, Mode mode = Mode::Tanh)
+ : RNNBase(inp_size, hid_size) {
+ if (mode == Mode::Tanh) {
+ mode_ = RNNBase::RNNMode::RNN_TANH;
+ } else if (mode == Mode::Relu) {
+ mode_ = RNNBase::RNNMode::RNN_RELU;
+ } else {
+ throw std::runtime_error("RNN Mode not supported");
+ }
+ }
+};
+
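+// Usage sketch (see test/cpp/api/rnn_t.cpp):
+//   auto lstm = LSTM(128, 64).nlayers(2).dropout(0.2).make();
+//   auto out = lstm->forward({x});  // out[0]: T x B x hidden,
+//                                   // out[1]: layers x (h, c) x B x hidden
+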
+} // namespace autograd
diff --git a/torch/csrc/api/include/torch/detail.h b/torch/csrc/api/include/torch/detail.h
new file mode 100644
index 0000000..487a7c8
--- /dev/null
+++ b/torch/csrc/api/include/torch/detail.h
@@ -0,0 +1,70 @@
+#pragma once
+
+#include <map>
+#include <memory>
+
+#include "torch/csrc/autograd/engine.h"
+#include "torch/csrc/autograd/grad_mode.h"
+
+// for AutoGPU. Usage:
+// AutoGPU gpu_raii(1);
+// While this object is in scope, all of your GPU tensors will go to GPU 1
+#include "torch/csrc/utils/auto_gpu.h"
+
+#define AUTOGRAD_OPTIMIZER_CLASS(Type) \
+ class Type : public autograd::Optimizer_CRTP<Type>
+#define AUTOGRAD_KWARG(CLS, TYP, NAME, DEFAULT, OPTION) \
+ TYP NAME##_ = DEFAULT; \
+ CLS& NAME(TYP x = OPTION) { \
+ NAME##_ = x; \
+ return *this; \
+ }
+
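+// As an example of what this generates: AUTOGRAD_KWARG(SGD, double, momentum,
+// 0, 0) produces a member `momentum_` defaulting to 0 plus a chainable setter,
+// so options compose fluently:
+//   SGD(model, 1e-1).momentum(0.9).nesterov().make();
+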
+namespace {
+namespace tag = torch::autograd;
+using IntVec = decltype(std::declval<at::IntList>().vec());
+} // namespace
+
+namespace autograd {
+namespace detail {
+extern tag::Engine engine;
+}
+
+class ContainerImpl;
+class OptimizerImpl;
+using Variable = tag::Variable;
+using variable_list = tag::variable_list;
+using Tensor = at::Tensor;
+using Container = std::shared_ptr<ContainerImpl>;
+using Optimizer = std::shared_ptr<OptimizerImpl>;
+
+void backward(Tensor loss, bool keep_graph = false);
+
+inline Variable Var(at::Tensor data, bool requires_grad = true) {
+ return tag::make_variable(data, requires_grad);
+}
+
+// Note: the gradient mode this toggles is thread local!
+inline void set_grad_enabled(bool val = true) {
+ tag::GradMode::set_enabled(val);
+}
+
+// RAII guard that disables gradient recording on this thread while in scope
+class no_grad_guard {
+ public:
+ no_grad_guard() {
+ tag::GradMode::set_enabled(false);
+ }
+
+ ~no_grad_guard() {
+ tag::GradMode::set_enabled(true);
+ }
+};
+
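+// Usage sketch:
+//   {
+//     no_grad_guard guard;
+//     auto y = model->forward({x});  // no autograd graph is recorded here
+//   }  // gradient mode is re-enabled when the guard goes out of scope
+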
+void setSeed(uint64_t seed);
+
+int getNumGPUs();
+bool hasCuda();
+bool hasCudnn();
+
+} // namespace autograd
diff --git a/torch/csrc/api/include/torch/optimizers.h b/torch/csrc/api/include/torch/optimizers.h
new file mode 100644
index 0000000..db132b4
--- /dev/null
+++ b/torch/csrc/api/include/torch/optimizers.h
@@ -0,0 +1,139 @@
+#pragma once
+
+#include "torch/containers.h"
+#include "torch/detail.h"
+
+#include "cereal/access.hpp"
+#include "cereal/cereal.hpp"
+
+namespace autograd {
+class OptimizerImpl {
+ public:
+ OptimizerImpl(Container model) : model_(model) {}
+ virtual void init_state() {}
+ virtual void step() = 0;
+ void zero_grad();
+
+ void set_model(Container model);
+
+ protected:
+ OptimizerImpl() {}
+ Container model_;
+};
+
+template <class Derived>
+class Optimizer_CRTP : public OptimizerImpl {
+ public:
+ Optimizer_CRTP(Container model) : OptimizerImpl(model) {}
+ std::shared_ptr<Derived> make() const {
+ auto ptr = std::make_shared<Derived>(*static_cast<const Derived*>(this));
+ ptr->init_state();
+ return ptr;
+ }
+
+ protected:
+ Optimizer_CRTP() {}
+};
+
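+// Usage sketch: like containers, concrete optimizers are built via make():
+//   auto optim = Adam(model, 1e-3).weight_decay(1e-6).make();
+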
+AUTOGRAD_OPTIMIZER_CLASS(SGD) {
+ public:
+ SGD(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
+ AUTOGRAD_KWARG(SGD, double, momentum, 0, 0);
+ AUTOGRAD_KWARG(SGD, double, dampening, 0, 0);
+ AUTOGRAD_KWARG(SGD, double, weight_decay, 0, 0);
+ AUTOGRAD_KWARG(SGD, bool, nesterov, false, true);
+ double lr_;
+ void step() override;
+ void init_state() override;
+
+ template <class Archive>
+ void serialize(Archive & ar) {
+ ar(CEREAL_NVP(momentum_buffers_));
+ }
+
+ private:
+ friend class cereal::access;
+ SGD() {}
+ std::unordered_map<std::string, at::Tensor> momentum_buffers_;
+};
+
+AUTOGRAD_OPTIMIZER_CLASS(Adagrad) {
+ public:
+ Adagrad(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
+ AUTOGRAD_KWARG(Adagrad, double, lr_decay, 0, 0);
+ AUTOGRAD_KWARG(Adagrad, double, weight_decay, 0, 0);
+ double lr_;
+ void step() override;
+ void init_state() override;
+
+ template <class Archive>
+ void serialize(Archive & ar) {
+ ar(CEREAL_NVP(sum_));
+ ar(CEREAL_NVP(step_));
+ }
+
+ private:
+ friend class cereal::access;
+ Adagrad() {}
+ std::unordered_map<std::string, at::Tensor> sum_;
+ std::unordered_map<std::string, double> step_;
+};
+
+AUTOGRAD_OPTIMIZER_CLASS(RMSprop) {
+ public:
+ RMSprop(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
+ AUTOGRAD_KWARG(RMSprop, double, alpha, 0.99, 0.99);
+ AUTOGRAD_KWARG(RMSprop, double, eps, 1e-8, 1e-8);
+ AUTOGRAD_KWARG(RMSprop, double, weight_decay, 0, 0);
+ AUTOGRAD_KWARG(RMSprop, double, momentum, 0, 0);
+ AUTOGRAD_KWARG(RMSprop, bool, centered, false, true);
+
+ double lr_;
+ void step() override;
+ void init_state() override;
+
+ template <class Archive>
+ void serialize(Archive & ar) {
+ ar(CEREAL_NVP(square_avg_buffer_));
+ ar(CEREAL_NVP(momentum_buffer_));
+ ar(CEREAL_NVP(grad_avg_buffer_));
+ }
+
+ private:
+ friend class cereal::access;
+ RMSprop() {}
+ std::unordered_map<std::string, at::Tensor> square_avg_buffer_;
+ std::unordered_map<std::string, at::Tensor> momentum_buffer_;
+ std::unordered_map<std::string, at::Tensor> grad_avg_buffer_;
+};
+
+AUTOGRAD_OPTIMIZER_CLASS(Adam) {
+ public:
+ Adam(Container model, double lr) : Optimizer_CRTP(model), lr_(lr) {}
+ AUTOGRAD_KWARG(Adam, double, beta1, 0.9, 0.9);
+ AUTOGRAD_KWARG(Adam, double, beta2, 0.999, 0.999);
+ AUTOGRAD_KWARG(Adam, double, weight_decay, 0, 0);
+ AUTOGRAD_KWARG(Adam, double, eps, 1e-8, 1e-8);
+ AUTOGRAD_KWARG(Adam, bool, amsgrad, false, true);
+ double lr_;
+ void step() override;
+ void init_state() override;
+
+ template <class Archive>
+ void serialize(Archive & ar) {
+ ar(CEREAL_NVP(step_buffer_),
+ CEREAL_NVP(exp_avg_buffer_),
+ CEREAL_NVP(exp_avg_sq_buffer_),
+ CEREAL_NVP(max_exp_avg_sq_buffer_));
+ }
+
+ private:
+ friend class cereal::access;
+ Adam() {}
+ std::unordered_map<std::string, int> step_buffer_;
+ std::unordered_map<std::string, at::Tensor> exp_avg_buffer_;
+ std::unordered_map<std::string, at::Tensor> exp_avg_sq_buffer_;
+ std::unordered_map<std::string, at::Tensor> max_exp_avg_sq_buffer_;
+};
+
+} // namespace autograd
diff --git a/torch/csrc/api/include/torch/serialization.h b/torch/csrc/api/include/torch/serialization.h
new file mode 100644
index 0000000..1a253ae
--- /dev/null
+++ b/torch/csrc/api/include/torch/serialization.h
@@ -0,0 +1,236 @@
+#pragma once
+
+#include <fstream>
+
+#include "cereal/archives/binary.hpp"
+#include "cereal/types/polymorphic.hpp"
+
+#include "cereal/types/string.hpp"
+#include "cereal/types/unordered_map.hpp"
+#include "cereal/types/vector.hpp"
+
+namespace autograd {
+
+// Some convenience functions for saving and loading
+template <typename T>
+void save(std::ostream& stream, T const& obj) {
+ cereal::BinaryOutputArchive archive(stream);
+ archive(*obj);
+}
+template <typename T>
+void load(std::istream& stream, T& obj) {
+ cereal::BinaryInputArchive archive(stream);
+ archive(*obj);
+}
+template <typename T>
+void save(std::ostream& stream, T const* obj) {
+ cereal::BinaryOutputArchive archive(stream);
+ archive(*obj);
+}
+template <typename T>
+void load(std::istream& stream, T* obj) {
+ cereal::BinaryInputArchive archive(stream);
+ archive(*obj);
+}
+template <typename T>
+void save(std::string const& path, T const& obj) {
+ std::ofstream os(path, std::ios::binary);
+ autograd::save(os, obj);
+}
+template <typename T>
+void load(std::string const& path, T& obj) {
+ std::ifstream is(path, std::ios::binary);
+ autograd::load(is, obj);
+}
+
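+// Usage sketch: these overloads dereference their argument, so they work with
+// the shared_ptr-based Container and Optimizer handles directly:
+//   std::stringstream ss;
+//   save(ss, model);   // model is a Container
+//   load(ss, model2);  // model2 must have matching parameter names
+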
+namespace detail {
+
+// We use our own hard-coded type<->id mapping so that serialization is robust
+// wrt changes in ATen; see e.g. https://git.io/vxd6R
+// The mapping is consistent with the ScalarType enum as of pytorch version
+// v0.1.11-7675-ge94c67e.
+inline int32_t scalarTypeId(at::ScalarType type) {
+ switch (type) {
+ case at::ScalarType::Byte: return 0;
+ case at::ScalarType::Char: return 1;
+ case at::ScalarType::Short: return 2;
+ case at::ScalarType::Int: return 3;
+ case at::ScalarType::Long: return 4;
+ case at::ScalarType::Half: return 5;
+ case at::ScalarType::Float: return 6;
+ case at::ScalarType::Double: return 7;
+ case at::ScalarType::Undefined: return 8;
+ default:
+ throw std::runtime_error(
+ "Unknown scalar type: " + std::to_string(static_cast<int>(type)));
+ }
+}
+
+inline at::ScalarType scalarTypeFromId(int32_t id) {
+ switch (id) {
+ case 0: return at::ScalarType::Byte;
+ case 1: return at::ScalarType::Char;
+ case 2: return at::ScalarType::Short;
+ case 3: return at::ScalarType::Int;
+ case 4: return at::ScalarType::Long;
+ case 5: return at::ScalarType::Half;
+ case 6: return at::ScalarType::Float;
+ case 7: return at::ScalarType::Double;
+ case 8: return at::ScalarType::Undefined;
+ default:
+ throw std::runtime_error("Unknown scalar type id: " + std::to_string(id));
+ }
+}
+
+inline int32_t backendId(at::Backend backend) {
+ switch (backend) {
+ case at::Backend::CPU: return 0;
+ case at::Backend::CUDA: return 1;
+ case at::Backend::SparseCPU: return 2;
+ case at::Backend::SparseCUDA: return 3;
+ case at::Backend::Undefined: return 4;
+ default:
+ throw std::runtime_error(
+ "Unknown backend: " + std::to_string(static_cast<int>(backend)));
+ }
+}
+
+inline at::Backend backendFromId(int32_t id) {
+ switch (id) {
+ case 0: return at::Backend::CPU;
+ case 1: return at::Backend::CUDA;
+ case 2: return at::Backend::SparseCPU;
+ case 3: return at::Backend::SparseCUDA;
+ case 4: return at::Backend::Undefined;
+ default:
+ throw std::runtime_error("Unknown backend id: " + std::to_string(id));
+ }
+}
+
+} // namespace detail
+} // namespace autograd
+
+// This is super ugly and I don't know how to simplify it
+CEREAL_REGISTER_TYPE(autograd::SGD);
+CEREAL_REGISTER_POLYMORPHIC_RELATION(autograd::OptimizerImpl, autograd::SGD);
+CEREAL_REGISTER_TYPE(autograd::Adagrad);
+CEREAL_REGISTER_POLYMORPHIC_RELATION(
+ autograd::OptimizerImpl,
+ autograd::Adagrad);
+CEREAL_REGISTER_TYPE(autograd::RMSprop);
+CEREAL_REGISTER_POLYMORPHIC_RELATION(
+ autograd::OptimizerImpl,
+ autograd::RMSprop);
+CEREAL_REGISTER_TYPE(autograd::Adam);
+CEREAL_REGISTER_POLYMORPHIC_RELATION(autograd::OptimizerImpl, autograd::Adam);
+
+namespace cereal {
+
+namespace agimpl {
+
+template <class Archive>
+void saveBinary(Archive& archive, void const* data, std::size_t size) {
+ // In general, there's no direct `saveBinary`-like method on archives
+ std::vector<char> v(
+ static_cast<char const*>(data), static_cast<char const*>(data) + size);
+ archive(v);
+}
+template <>
+inline void
+saveBinary(BinaryOutputArchive& archive, void const* data, std::size_t size) {
+ // Writes to output stream without extra copy
+ archive.saveBinary(data, size);
+}
+
+template <class Archive>
+void loadBinary(Archive& archive, void* data, std::size_t size) {
+ // In general, there's no direct `loadBinary`-like method on archives
+ std::vector<char> v(size);
+ archive(v);
+ std::memcpy(data, v.data(), size);
+}
+template <>
+inline void
+loadBinary(BinaryInputArchive& archive, void* data, std::size_t size) {
+ // Read from input stream without extra copy
+ archive.loadBinary(data, size);
+}
+
+} // namespace agimpl
+
+// Gradients will not be saved for variables
+template <class Archive>
+void save(Archive& archive, at::Tensor const& tensor) {
+ if (!tensor.defined()) {
+ int32_t typeId = ::autograd::detail::scalarTypeId(at::ScalarType::Undefined);
+ archive(CEREAL_NVP(typeId));
+ return;
+ } else {
+ int32_t typeId = ::autograd::detail::scalarTypeId(tensor.type().scalarType());
+ archive(CEREAL_NVP(typeId));
+ }
+ auto sizes = std::vector<int64_t>();
+ for (auto s : tensor.sizes()) {
+ sizes.push_back(s);
+ }
+ auto contig = tensor.toBackend(at::kCPU).contiguous();
+ int32_t backend = ::autograd::detail::backendId(tensor.type().backend());
+
+ archive(CEREAL_NVP(backend), CEREAL_NVP(sizes));
+ agimpl::saveBinary(
+ archive,
+ contig.data_ptr(),
+ tensor.numel() * tensor.type().elementSizeInBytes());
+}
+
+/**
+ * We follow these rules for loading:
+ * 1. If the tensor is defined and has the same ScalarType as the saved
+ *    tensor, we simply copy the data into it, resizing as needed.
+ * 2. Otherwise, we overwrite the provided tensor with one of the right
+ *    type and backend.
+ **/
+template <class Archive>
+void load(Archive& archive, at::Tensor& tensor) {
+ at::ScalarType type;
+ int32_t typeId;
+ archive(CEREAL_NVP(typeId));
+ type = ::autograd::detail::scalarTypeFromId(typeId);
+ if (type == at::ScalarType::Undefined) {
+ tensor = at::Tensor();
+ return;
+ }
+
+ int32_t backendId;
+ auto sizes = std::vector<int64_t>();
+ archive(CEREAL_NVP(backendId), CEREAL_NVP(sizes));
+
+ at::Backend backend = ::autograd::detail::backendFromId(backendId);
+ if (!tensor.defined() || tensor.type().scalarType() != type) {
+ tensor = at::getType(backend, type).tensor();
+ }
+ tensor.resize_(sizes);
+
+ if (tensor.type().is_cuda()) {
+ // should actually use cudamemcpy probably
+ auto cputensor = at::CPU(tensor.type().scalarType()).tensor(sizes);
+ agimpl::loadBinary(
+ archive,
+ cputensor.data_ptr(),
+ cputensor.numel() * cputensor.type().elementSizeInBytes());
+ tensor.copy_(cputensor);
+ } else {
+ agimpl::loadBinary(
+ archive,
+ tensor.data_ptr(),
+ tensor.numel() * tensor.type().elementSizeInBytes());
+ }
+}
+
+template <class Archive>
+void load(Archive& archive, tag::Variable& var) {
+ load(archive, var.data());
+}
+
+} // namespace cereal
diff --git a/torch/csrc/api/src/containers.cpp b/torch/csrc/api/src/containers.cpp
new file mode 100644
index 0000000..78ef5b7
--- /dev/null
+++ b/torch/csrc/api/src/containers.cpp
@@ -0,0 +1,633 @@
+#include "torch/containers.h"
+
+namespace autograd {
+std::map<std::string, Variable> ContainerImpl::parameters() const {
+ std::map<std::string, Variable> ret;
+ for (auto pair : children_) {
+ auto& name = pair.first;
+ auto& child = pair.second;
+ for (auto& p : child->parameters()) {
+ ret[name + "." + p.first] = p.second;
+ }
+ }
+ for (auto pair : params_) {
+ ret[pair.first] = pair.second;
+ }
+ return ret;
+}
+
+Variable& ContainerImpl::param(std::string const& name) {
+ ContainerImpl* container = this;
+ auto begin = 0;
+ while (true) {
+ auto dot_pos = name.find('.', begin);
+ if (dot_pos == std::string::npos) {
+ break;
+ }
+
+ auto child_name = name.substr(begin, dot_pos - begin);
+ auto it = container->children_.find(child_name);
+ if (it == container->children_.end()) {
+ throw std::runtime_error("No such child: " + child_name);
+ }
+
+ container = it->second.get();
+ begin = dot_pos + 1; // Skip the dot
+ }
+
+ auto param_name = name.substr(begin);
+ auto it = container->params_.find(param_name);
+  if (it == container->params_.end()) {
+ throw std::runtime_error("No such param: " + param_name);
+ }
+ return it->second;
+}
+
+void ContainerImpl::cuda() {
+ for (auto& pair : children_) {
+ pair.second->cuda();
+ }
+ cuda_ = true;
+ auto copied = params_;
+ params_.clear();
+ initialize_parameters();
+ for (auto pair : params_) {
+ pair.second.data().copy_(copied[pair.first].data());
+ }
+};
+
+void ContainerImpl::cpu() {
+ for (auto& pair : children_) {
+ pair.second->cpu();
+ }
+ cuda_ = false;
+ auto copied = params_;
+ params_.clear();
+ initialize_parameters();
+ for (auto pair : params_) {
+ pair.second.data().copy_(copied[pair.first].data());
+ }
+};
+
+void ContainerImpl::train() {
+ for (auto& pair : children_) {
+ pair.second->train();
+ }
+ train_ = true;
+}
+
+void ContainerImpl::eval() {
+ for (auto& pair : children_) {
+ pair.second->eval();
+ }
+ train_ = false;
+}
+
+Container ContainerImpl::add(Container m, std::string const& name) {
+ if (this->children_.find(name) != this->children_.end()) {
+ throw std::runtime_error("Trying to add container that already exists");
+ }
+ if (std::find(name.begin(), name.end(), '.') != name.end()) {
+ // We can't allow containers with dots in their names, as that would make
+ // their parameters not findable with parameters().
+ throw std::runtime_error("Trying to add parameter with a '.' in its name");
+ }
+ this->children_[name] = std::move(m);
+ return this->children_[name];
+}
+
+Variable& ContainerImpl::add(Variable v, std::string const& name) {
+ if (this->params_.find(name) != this->params_.end()) {
+ throw std::runtime_error("Trying to add parameter that already exists");
+ }
+ if (std::find(name.begin(), name.end(), '.') != name.end()) {
+ // We can't allow parameters with dots in their names, as that would make
+ // them not findable with parameters().
+ throw std::runtime_error("Trying to add parameter with a '.' in its name");
+ }
+ this->params_[name] = v;
+ return this->params_[name];
+}
+
+at::Type& ContainerImpl::DefaultTensor(at::ScalarType s) {
+ if (cuda_)
+ return at::CUDA(s);
+ else
+ return at::CPU(s);
+}
+
+variable_list Linear::forward(variable_list input) {
+ auto x = input[0];
+ if (x.ndimension() == 2 && !no_bias_) {
+ // Fused op is marginally faster
+ assert(x.size(1) == weight.size(1));
+ return variable_list({at::addmm(bias, x, weight.t())});
+ }
+
+ auto output = x.matmul(weight.t());
+ if (!no_bias_) {
+ output += bias;
+ }
+ return variable_list({output});
+}
+
+void Linear::reset_parameters() {
+ auto stdv = 1.0 / std::sqrt(weight.size(1));
+ for (auto& p : parameters()) {
+ p.second.data().uniform_(-stdv, stdv);
+ }
+}
+
+void Linear::initialize_parameters() {
+ weight = this->add(
+ Var(DefaultTensor(at::kFloat).tensor({nout, nin}), true), "weight");
+ if (!no_bias_) {
+ bias =
+ this->add(Var(DefaultTensor(at::kFloat).tensor({nout}), true), "bias");
+ }
+}
+
+variable_list Embedding::forward(variable_list input) {
+ auto x = input[0];
+ return variable_list({at::embedding(weight, x, -1, false, false)});
+}
+
+void Embedding::reset_parameters() {
+ for (auto& p : parameters()) {
+ p.second.data().normal_(0, 1);
+ }
+}
+
+void Embedding::initialize_parameters() {
+ weight = this->add(
+ Var(DefaultTensor(at::kFloat).tensor({num_embeddings, embedding_dim}),
+ true),
+ "weight");
+}
+
+void Conv::initialize_parameters() {
+ if (!transposed_) {
+ for (auto pad : output_padding_) {
+ if (pad != 0) {
+ throw std::runtime_error(
+ "Only transposed convolutions support output padding!");
+ }
+ }
+ }
+
+ IntVec wsize;
+ if (transposed_) {
+ wsize.push_back(in_channels_);
+ wsize.push_back(out_channels_ / groups_);
+ } else {
+ wsize.push_back(out_channels_);
+ wsize.push_back(in_channels_ / groups_);
+ }
+ wsize.insert(wsize.end(), ks_.begin(), ks_.end());
+ weight =
+ this->add(Var(DefaultTensor(at::kFloat).tensor(wsize), true), "weight");
+ if (!no_bias_) {
+ bias = this->add(
+ Var(DefaultTensor(at::kFloat).tensor({out_channels_}), true), "bias");
+ } else {
+ assert(!bias.defined());
+ }
+}
+
+void Conv::reset_parameters() {
+ auto n = in_channels_;
+ for (auto k : ks_)
+ n *= k;
+ auto stdv = 1.0 / std::sqrt(n);
+ for (auto& p : parameters()) {
+ p.second.data().uniform_(-stdv, stdv);
+ }
+}
+
+variable_list Conv::forward(variable_list input) {
+ auto x = input[0];
+ if (Nd_ == 1) {
+ assert(x.ndimension() == 3);
+ x = x.unsqueeze(-1); // TODO: Use conv1d once available
+ } else if (Nd_ == 2) {
+ assert(x.ndimension() == 4);
+ } else if (Nd_ == 3) {
+ assert(x.ndimension() == 5);
+ } else {
+ throw std::runtime_error("Only Conv{1,2,3}d are supported");
+ }
+
+ Variable out;
+ if (Nd_ == 1 || Nd_ == 2) {
+ if (transposed_) {
+ out = at::conv_transpose2d(
+ x,
+ weight,
+ bias,
+ stride_,
+ padding_,
+ output_padding_,
+ groups_,
+ dilation_);
+ } else {
+ out = at::conv2d(x, weight, bias, stride_, padding_, dilation_, groups_);
+ }
+ } else if (Nd_ == 3) {
+ if (transposed_) {
+ out = at::conv_transpose3d(
+ x,
+ weight,
+ bias,
+ stride_,
+ padding_,
+ output_padding_,
+ groups_,
+ dilation_);
+ } else {
+ out = at::conv3d(x, weight, bias, stride_, padding_, dilation_, groups_);
+ }
+ }
+
+ return variable_list({out});
+}
+
+void BatchNorm::initialize_parameters() {
+ if (affine_) {
+ weight = this->add(
+ Var(DefaultTensor(at::kFloat).tensor(num_features_), true), "weight");
+ bias = this->add(
+ Var(DefaultTensor(at::kFloat).tensor(num_features_), true), "bias");
+ }
+
+ if (stateful_) {
+ running_mean = Var(DefaultTensor(at::kFloat).zeros({num_features_}), false);
+ running_var = Var(DefaultTensor(at::kFloat).ones({num_features_}), false);
+ }
+}
+
+void BatchNorm::reset_parameters() {
+ if (affine_) {
+ weight.data().uniform_();
+ bias.data().zero_();
+ }
+
+ if (stateful_) {
+ running_mean.data().zero_();
+ running_var.data().fill_(1);
+ }
+}
+
+variable_list BatchNorm::forward(variable_list inputs) {
+ auto& input = inputs[0];
+ auto& running_mean = (stateful_ ? this->running_mean : inputs[1]);
+ auto& running_var = (stateful_ ? this->running_var : inputs[2]);
+
+ if (train_) {
+ const auto num_channels = input.dim() > 1 ? input.size(1) : 1;
+ if (input.numel() / num_channels <= 1) {
+ throw std::runtime_error(
+ "BatchNorm expected more than 1 value per channel when training!");
+ }
+ }
+
+ auto output = at::batch_norm(
+ input,
+ weight,
+ bias,
+ running_mean,
+ running_var,
+ train_,
+ momentum_,
+ eps_,
+ hasCudnn());
+
+ return variable_list({output});
+}
+
+template <typename Derived>
+void RNNBase<Derived>::initialize_containers() {
+ auto gate_size = hidden_size_;
+ if (mode_ == RNNMode::LSTM) {
+ gate_size *= 4;
+ } else if (mode_ == RNNMode::GRU) {
+ gate_size *= 3;
+ }
+
+ for (auto i = 0U; i < nlayers_; i++) {
+ auto input_size = (i == 0) ? input_size_ : hidden_size_;
+ i2h.push_back(this->add(
+ Linear(input_size, gate_size).no_bias(no_bias_).make(),
+ "i2h_" + std::to_string(i)));
+ h2h.push_back(this->add(
+ Linear(hidden_size_, gate_size).no_bias(no_bias_).make(),
+ "h2h_" + std::to_string(i)));
+ }
+ if (dropout_ > 0)
+ dropout_module = Dropout(dropout_).make();
+ this->flatten_parameters();
+}
+
+template <typename Derived>
+void RNNBase<Derived>::reset_parameters() {
+ auto stdv = 1.0 / std::sqrt(hidden_size_);
+ for (auto& p : this->parameters()) {
+ p.second.data().uniform_(-stdv, stdv);
+ }
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::GRU_cell_forward(variable_list inputs, int i) {
+ auto x = inputs[0];
+ auto hx = inputs[1].defined()
+ ? inputs[1]
+ : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));
+
+ auto gi = i2h[i]->forward({x})[0];
+ auto gh = h2h[i]->forward({hx})[0];
+ auto gic = gi.chunk(3, 1);
+ auto ghc = gh.chunk(3, 1);
+
+ auto reset_gate = (gic[0] + ghc[0]).sigmoid_();
+ auto input_gate = (gic[1] + ghc[1]).sigmoid_();
+ auto new_gate = (gic[2] + reset_gate * ghc[2]).tanh_();
+ auto hy = new_gate + input_gate * (hx - new_gate);
+
+ return variable_list({hy});
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::RNN_TANH_cell_forward(
+ variable_list inputs,
+ int i) {
+ auto x = inputs[0];
+ auto hx = inputs[1].defined()
+ ? inputs[1]
+ : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));
+
+ auto h = (i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0]).tanh();
+ return variable_list({h});
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::RNN_RELU_cell_forward(
+ variable_list inputs,
+ int i) {
+ auto x = inputs[0];
+ auto hx = inputs[1].defined()
+ ? inputs[1]
+ : Var(this->DefaultTensor(at::kFloat).zeros({x.size(0), hidden_size_}));
+
+ auto h = (i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0]).clamp_min(0);
+ return variable_list({h});
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::LSTM_cell_forward(variable_list inputs, int i) {
+ auto x = inputs[0];
+ auto hid = inputs[1].defined()
+ ? inputs[1]
+ : Var(this->DefaultTensor(at::kFloat)
+ .zeros({2, x.size(0), hidden_size_}));
+ auto hx = hid[0];
+ auto cx = hid[1];
+
+ auto gates = i2h[i]->forward({x})[0] + h2h[i]->forward({hx})[0];
+
+ auto chunked = gates.chunk(4, 1);
+ auto in_gate = chunked[0].sigmoid();
+ auto forget_gate = chunked[1].sigmoid();
+ auto cell_gate = chunked[2].tanh();
+ auto out_gate = chunked[3].sigmoid();
+
+ auto cy = (forget_gate * cx) + (in_gate * cell_gate);
+ auto hy = out_gate * cy.tanh();
+
+ return variable_list({at::stack({hy, cy}, 0)});
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::cell_forward(variable_list inputs, int i) {
+ if (mode_ == RNNMode::LSTM)
+ return LSTM_cell_forward(inputs, i);
+ else if (mode_ == RNNMode::GRU)
+ return GRU_cell_forward(inputs, i);
+ else if (mode_ == RNNMode::RNN_TANH)
+ return RNN_TANH_cell_forward(inputs, i);
+ else if (mode_ == RNNMode::RNN_RELU)
+ return RNN_RELU_cell_forward(inputs, i);
+ else
+ throw std::runtime_error("No such RNN mode");
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::autograd_forward(variable_list inputs) {
+ auto inp = inputs[0];
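+  // Inputs are time-major, i.e. inp has shape (seq_len, batch, input_size);
+  // the graph is unrolled explicitly over time steps and layers.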
+
+ std::vector<Tensor> hidden;
+ for (size_t i = 0; i < nlayers_; i++) {
+ hidden.push_back(inputs[1].defined() ? inputs[1][i] : tag::Variable());
+ }
+
+ auto output =
+ Var(this->DefaultTensor(at::kFloat)
+ .zeros({inp.size(0), inp.size(1), hidden_size_}),
+ false);
+ for (auto t = 0U; t < inp.size(0); t++) {
+ auto x = inp.select(0, t);
+ for (size_t i = 0; i < nlayers_; i++) {
+ auto layer_output = cell_forward({x, hidden[i]}, i);
+ hidden[i] = layer_output[0];
+ if (mode_ == RNNMode::LSTM) {
+ x = hidden[i][0];
+ } else {
+ x = hidden[i];
+ }
+ auto output_slice = output.select(0, t);
+ output_slice.copy_(x);
+ if (dropout_ > 0 && i != nlayers_ - 1) {
+ x = dropout_module->forward({x})[0];
+ }
+ }
+ }
+
+ auto hidout = at::stack(hidden, 0);
+ return variable_list({output, hidout});
+}
+
+template <typename Derived>
+bool RNNBase<Derived>::flatten_parameters() {
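+  // Copies all layer weights into one contiguous buffer so cuDNN can run its
+  // fused kernels. Returns false (keeping the slow path) when cuDNN cannot be
+  // used or when parameter storages alias each other.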
+ data_ptrs_.clear();
+ auto anyParam = i2h[0]->params_.begin()->second;
+ if (!anyParam.is_cuda() || !at::cudnn_is_acceptable(anyParam)) {
+ return false;
+ }
+ std::unordered_set<void*> unique_data_ptrs;
+ auto params = this->parameters();
+ for (auto& p : params) {
+ unique_data_ptrs.insert(p.second.data().data_ptr());
+ }
+  // TODO: PyTorch says: if any parameters alias, we fall back to the slower,
+  // copying code path. This is a sufficient check, because overlapping
+  // parameter buffers that don't completely alias would break the assumptions
+  // of the uniqueness check in Module.named_parameters(). But I'm not sure if
+  // this is the case for us.
+ if (unique_data_ptrs.size() != params.size()) {
+ return false;
+ }
+
+ std::vector<Tensor> weight_list;
+ for (size_t i = 0; i < nlayers_; i++) {
+ weight_list.push_back(i2h[i]->param("weight"));
+ weight_list.push_back(h2h[i]->param("weight"));
+ if (!no_bias_) {
+ weight_list.push_back(i2h[i]->param("bias"));
+ weight_list.push_back(h2h[i]->param("bias"));
+ }
+ }
+ auto weight_stride0 = no_bias_ ? 2 : 4;
+
+ {
+ no_grad_guard guard;
+ flat_weight_ = at::_cudnn_rnn_flatten_weight(
+ weight_list,
+ weight_stride0,
+ input_size_,
+ mode_,
+ hidden_size_,
+ nlayers_,
+ false,
+ false); // batch_first and bidirectional, unsupported
+ }
+ for (auto& p : params) {
+ data_ptrs_.emplace_back(p.second.data().data_ptr());
+ }
+ return true;
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::CUDNN_forward(variable_list inputs) {
+ std::vector<Tensor> weight_list;
+ for (size_t i = 0; i < nlayers_; i++) {
+ weight_list.push_back(i2h[i]->param("weight"));
+ weight_list.push_back(h2h[i]->param("weight"));
+ if (!no_bias_) {
+ weight_list.push_back(i2h[i]->param("bias"));
+ weight_list.push_back(h2h[i]->param("bias"));
+ }
+ }
+ auto weight_stride0 = no_bias_ ? 2 : 4;
+
+ auto x = inputs[0];
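+  // x: (seq_len, batch, input_size) since batch_first is unsupported;
+  // hx/cx: (nlayers, batch, hidden_size).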
+ Variable hx, cx;
+ if (!inputs[1].defined()) {
+ hx = x.type().zeros({nlayers_, x.size(1), hidden_size_});
+ if (mode_ == RNNMode::LSTM) {
+ cx = x.type().zeros({nlayers_, x.size(1), hidden_size_});
+ }
+ } else {
+ hx = mode_ == RNNMode::LSTM ? inputs[1][0] : inputs[1];
+ cx = mode_ == RNNMode::LSTM ? inputs[1][1] : Variable();
+ }
+ auto dropout_state = x.type().tensor();
+
+ std::vector<void*> weight_data_ptrs;
+ auto params = this->parameters();
+ for (auto& p : params) {
+ weight_data_ptrs.emplace_back(p.second.data().data_ptr());
+ }
+ if (weight_data_ptrs != data_ptrs_) {
+ std::cerr << "Parameters are unflattened! Code path might be super slow. "
+ "Please call flatten_parameters() when you muck around with "
+ "storages!"
+ << std::endl;
+ flat_weight_ = Variable();
+ }
+
+ // tup = std::tuple of output, hy, cy, reserve, new_weight_buf
+ auto tup = _cudnn_rnn(
+ x,
+ weight_list,
+ weight_stride0,
+ flat_weight_,
+ hx,
+ cx,
+ mode_,
+ hidden_size_,
+ nlayers_,
+ false, // batch first
+ 0, // TODO waiting on dropout state descriptor in C++ pytorch
+ this->train_,
+ false, // bidirectional
+ {}, // packing not supported
+ dropout_state // TODO waiting on dropout state descriptor in C++ pytorch
+ );
+
+ Variable hidout = mode_ == RNNMode::LSTM
+ ? at::stack({std::get<1>(tup), std::get<2>(tup)}, 0)
+ : std::get<1>(tup);
+ Variable output = std::get<0>(tup);
+ return variable_list({output, hidout});
+}
+
+template <typename Derived>
+variable_list RNNBase<Derived>::forward(variable_list inputs) {
+ variable_list inp;
+ inp.push_back(inputs[0]);
+ if (inputs.size() > 1) {
+ inp.push_back(inputs[1]);
+ } else {
+ inp.push_back(Variable());
+ }
+
+ // Dropout descriptors aren't in C++ in PyTorch yet...
+ auto output = at::cudnn_is_acceptable(inp[0]) && dropout_ == 0
+ ? CUDNN_forward(inp)
+ : autograd_forward(inp);
+
+ return output;
+}
+
+template <typename Derived>
+void RNNBase<Derived>::cuda() {
+ Container_CRTP<Derived>::cuda();
+ flatten_parameters();
+}
+
+template <typename Derived>
+void RNNBase<Derived>::cpu() {
+ Container_CRTP<Derived>::cpu();
+ flatten_parameters();
+}
+
+variable_list Dropout::forward(variable_list inputs) {
+ if (p_ == 0 || !this->train_)
+ return inputs;
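+  // Inverted dropout: zero each element with probability p_ and scale the
+  // survivors by 1 / (1 - p_), so evaluation mode is the identity.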
+ variable_list lst;
+ for (auto x : inputs) {
+ auto noise = x.data().type().tensor(x.sizes());
+ noise = (noise.uniform_(0, 1) > p_)
+ .toType(x.type().scalarType())
+ .mul_(1. / (1 - p_));
+ lst.push_back(x * Var(noise));
+ }
+ return lst;
+}
+
+variable_list Dropout2d::forward(variable_list inputs) {
+ if (p_ == 0 || !this->train_)
+ return inputs;
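+  // Like Dropout, but the (N, C, 1, 1) mask broadcasts over the spatial
+  // dimensions, so entire channels are dropped together.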
+ variable_list lst;
+ for (auto x : inputs) {
+ auto noise = x.data().type().tensor({x.size(0), x.size(1), 1, 1});
+ noise = (noise.uniform_(0, 1) > p_)
+ .toType(x.type().scalarType())
+ .mul_(1. / (1 - p_));
+ lst.push_back(x * Var(noise));
+ }
+ return lst;
+}
+
+} // namespace autograd
diff --git a/torch/csrc/api/src/detail.cpp b/torch/csrc/api/src/detail.cpp
new file mode 100644
index 0000000..a2bd1cd
--- /dev/null
+++ b/torch/csrc/api/src/detail.cpp
@@ -0,0 +1,71 @@
+#include <ATen/Config.h>
+
+#include <algorithm>
+#include <cmath>
+#include <functional>
+#include <stdexcept>
+
+#if AT_CUDA_ENABLED()
+#include <THC/THCTensorRandom.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#endif
+
+#include "torch/detail.h"
+
+namespace autograd {
+namespace detail {
+tag::Engine engine;
+}
+
+void backward(Variable loss, bool keep_graph) {
+ tag::edge_list edgelst;
+ tag::variable_list varlst;
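+  // Seed the engine at the loss node with dLoss/dLoss = 1.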
+ edgelst.emplace_back(loss.grad_fn(), loss.output_nr());
+ varlst.emplace_back(Var(at::ones_like(loss.data()), false));
+ // create_graph should be set to true when we want to support double bwd
+ detail::engine.execute(edgelst, varlst, keep_graph, false);
+}
+
+void backward(Tensor loss, bool keep_graph) {
+ Variable tmp(loss);
+ backward(tmp, keep_graph);
+}
+
+void setSeed(uint64_t seed) {
+ at::globalContext().defaultGenerator(at::Backend::CPU).manualSeed(seed);
+#if AT_CUDA_ENABLED()
+ if (getNumGPUs() > 0) {
+ THCRandom_manualSeedAll(at::globalContext().lazyInitCUDA(), seed);
+ }
+#endif
+}
+
+int getNumGPUs() {
+#if AT_CUDA_ENABLED()
+ int count;
+ auto err = cudaGetDeviceCount(&count);
+ if (err == cudaErrorNoDevice) {
+ return 0;
+ } else if (err != cudaSuccess) {
+ std::string msg = "CUDA error (";
+ msg += std::to_string(err);
+ msg += "): ";
+ msg += cudaGetErrorString(err);
+ throw std::runtime_error(msg);
+ }
+ return count;
+#else
+ return 0;
+#endif
+}
+
+bool hasCuda() {
+ return getNumGPUs() > 0;
+}
+
+bool hasCudnn() {
+ return hasCuda() && AT_CUDNN_ENABLED();
+}
+
+} // namespace autograd
diff --git a/torch/csrc/api/src/optimizers.cpp b/torch/csrc/api/src/optimizers.cpp
new file mode 100644
index 0000000..2c90d8b
--- /dev/null
+++ b/torch/csrc/api/src/optimizers.cpp
@@ -0,0 +1,199 @@
+#include "torch/optimizers.h"
+
+namespace autograd {
+
+void OptimizerImpl::zero_grad() {
+ for (auto p : model_->parameters()) {
+ auto& grad = p.second.grad();
+ if (grad.defined()) {
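+      // Drop the gradient's autograd history before zeroing its storage in
+      // place, so any old graph it references can be freed.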
+ grad = grad.detach();
+ torch::autograd::as_variable_ref(grad).data().zero_();
+ }
+ }
+}
+
+void OptimizerImpl::set_model(Container model) {
+ model_ = model;
+}
+
+void SGD::step() {
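+  // Classic SGD with optional L2 penalty, momentum, dampening and Nesterov:
+  //   d_p = grad + weight_decay * p
+  //   buf = momentum * buf + (1 - dampening) * d_p  (buf = d_p on first step)
+  //   p  -= lr * (nesterov ? d_p + momentum * buf : buf)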
+ for (auto& pair : model_->parameters()) {
+ auto& name = pair.first;
+ auto& grad = pair.second.grad();
+ auto& p = pair.second.data();
+ if (!grad.defined())
+ continue;
+
+ auto d_p = torch::autograd::as_variable_ref(grad).data();
+ if (weight_decay_ > 0) {
+ d_p.add_(p, weight_decay_);
+    }
+
+ if (momentum_ != 0) {
+ at::Tensor buf;
+ if (momentum_buffers_.find(name) == momentum_buffers_.end()) {
+ buf = momentum_buffers_[name] = at::zeros_like(p);
+ buf.mul_(momentum_).add_(d_p);
+ } else {
+ buf = momentum_buffers_[name];
+ buf.mul_(momentum_).add_(d_p, 1 - dampening_);
+ }
+
+ if (nesterov_) {
+ d_p = d_p.add(buf, momentum_);
+ } else {
+ d_p = buf;
+ }
+ }
+
+ p.add_(d_p, -lr_);
+ }
+}
+
+void SGD::init_state() {
+ momentum_buffers_.clear();
+}
+
+/// Adapted from
+/// https://github.com/pytorch/pytorch/blob/master/torch/optim/adagrad.py
+void Adagrad::step() {
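+  // Adagrad update:
+  //   sum += grad^2
+  //   clr  = lr / (1 + (step - 1) * lr_decay)
+  //   p   -= clr * grad / (sqrt(sum) + 1e-10)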
+ for (auto& pair : model_->parameters()) {
+ auto& name = pair.first;
+ auto& grad = pair.second.grad();
+ auto& p = pair.second.data();
+ if (!grad.defined())
+ continue;
+
+ auto d_p = torch::autograd::as_variable_ref(grad).data();
+ if (weight_decay_ > 0) {
+ d_p.add_(p, weight_decay_);
+    }
+ auto& step = step_[name];
+ step += 1.0;
+ auto clr = lr_ / (1.0 + (step - 1.0) * lr_decay_);
+ at::Tensor buf;
+ if (sum_.find(name) == sum_.end()) {
+ buf = sum_[name] = at::zeros_like(p);
+ } else {
+ buf = sum_[name];
+ }
+
+ buf.addcmul_(d_p, d_p, 1.0);
+ at::Tensor std = buf.sqrt().add_(1e-10);
+ p.addcdiv_(d_p, std, -clr);
+ }
+}
+
+void Adagrad::init_state() {
+ sum_.clear();
+ step_.clear();
+}
+
+/// Adapted from
+/// https://github.com/pytorch/pytorch/blob/master/torch/optim/rmsprop.py
+void RMSprop::step() {
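+  // RMSprop update:
+  //   square_avg = alpha * square_avg + (1 - alpha) * grad^2
+  //   avg = sqrt(square_avg - grad_avg^2 if centered, else square_avg) + eps
+  //   with momentum: buf = momentum * buf + grad / avg; p -= lr * buf
+  //   otherwise:     p -= lr * grad / avg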
+ for (auto& pair : model_->parameters()) {
+ auto& name = pair.first;
+ auto& grad = pair.second.grad();
+ auto& p = pair.second.data();
+ if (!grad.defined())
+ continue;
+
+    if (square_avg_buffer_.find(name) == square_avg_buffer_.end()) {
+      square_avg_buffer_[name] = at::zeros_like(p);
+      if (momentum_ > 0) {
+        momentum_buffer_[name] = at::zeros_like(p);
+      }
+      if (centered_) {
+        grad_avg_buffer_[name] = at::zeros_like(p);
+      }
+    }
+
+ auto d_p = torch::autograd::as_variable_ref(grad).data();
+ if (weight_decay_ > 0) {
+ d_p.add_(p, weight_decay_);
+    }
+
+ auto& square_avg = square_avg_buffer_[name];
+ square_avg.mul_(alpha_).addcmul_(d_p, d_p, 1.0 - alpha_);
+
+ at::Tensor avg;
+ if (centered_) {
+ auto& grad_avg = grad_avg_buffer_[name];
+ grad_avg.mul_(alpha_).add_(d_p, 1.0 - alpha_);
+ avg = square_avg.addcmul(grad_avg, grad_avg, -1.0).sqrt().add_(eps_);
+ } else {
+ avg = square_avg.sqrt().add_(eps_);
+    }
+
+ if (momentum_ > 0) {
+ auto& buf = momentum_buffer_[name];
+ buf.mul_(momentum_).addcdiv_(d_p, avg);
+ p.add_(buf, -lr_);
+ } else {
+ p.addcdiv_(d_p, avg, -lr_);
+    }
+ }
+}
+
+void RMSprop::init_state() {
+ square_avg_buffer_.clear();
+ momentum_buffer_.clear();
+ grad_avg_buffer_.clear();
+}
+
+void Adam::step() {
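+  // Adam update (optionally AMSGrad, which uses the running max of the
+  // second-moment estimate in the denominator):
+  //   m = beta1 * m + (1 - beta1) * grad
+  //   v = beta2 * v + (1 - beta2) * grad^2
+  //   step_size = lr * sqrt(1 - beta2^t) / (1 - beta1^t)
+  //   p -= step_size * m / (sqrt(v) + eps)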
+ for (auto& pair : model_->parameters()) {
+ auto& name = pair.first;
+ auto& grad = pair.second.grad();
+ auto& p = pair.second.data();
+ if (!grad.defined())
+ continue;
+
+ if (step_buffer_.find(name) == step_buffer_.end()) {
+ step_buffer_[name] = 0;
+ exp_avg_buffer_[name] = at::zeros_like(p);
+ exp_avg_sq_buffer_[name] = at::zeros_like(p);
+ if (amsgrad_) {
+ max_exp_avg_sq_buffer_[name] = at::zeros_like(p);
+      }
+ }
+
+ auto& step = step_buffer_[name];
+ auto& exp_avg = exp_avg_buffer_[name];
+ auto& exp_avg_sq = exp_avg_sq_buffer_[name];
+
+ step += 1;
+
+ auto d_p = torch::autograd::as_variable_ref(grad).data();
+ if (weight_decay_ > 0) {
+ d_p.add_(p, weight_decay_);
+ }
+
+ exp_avg.mul_(beta1_).add_(d_p, 1 - beta1_);
+ exp_avg_sq.mul_(beta2_).addcmul_(d_p, d_p, 1 - beta2_);
+
+ at::Tensor denom;
+ if (amsgrad_) {
+ auto& max_exp_avg_sq = max_exp_avg_sq_buffer_[name];
+ at::max_out(max_exp_avg_sq, max_exp_avg_sq, exp_avg_sq);
+ denom = max_exp_avg_sq.sqrt().add_(eps_);
+ } else {
+ denom = exp_avg_sq.sqrt().add_(eps_);
+    }
+
+ auto bias_correction1 = 1 - std::pow(beta1_, step);
+ auto bias_correction2 = 1 - std::pow(beta2_, step);
+ auto step_size = lr_ * std::sqrt(bias_correction2) / bias_correction1;
+
+ p.addcdiv_(exp_avg, denom, -step_size);
+ }
+}
+
+void Adam::init_state() {
+ step_buffer_.clear();
+ exp_avg_buffer_.clear();
+ exp_avg_sq_buffer_.clear();
+}
+
+} // namespace autograd