#include <gtest/gtest.h>

#include <torch/nn/module.h>
#include <torch/nn/modules/batchnorm.h>
#include <torch/nn/modules/conv.h>
#include <torch/nn/modules/dropout.h>
#include <torch/nn/modules/embedding.h>
#include <torch/nn/modules/functional.h>
#include <torch/nn/modules/linear.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <test/cpp/api/support.h>

using namespace torch::nn;
using namespace torch::test;

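// A small three-layer MLP used by the container and parameter tests below.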
class TestModel : public torch::nn::Module {
 public:
  TestModel()
      : l1(register_module("l1", Linear(10, 3))),
        l2(register_module("l2", Linear(3, 5))),
        l3(register_module("l3", Linear(5, 100))) {}

  Linear l1, l2, l3;
};

class NestedModel : public torch::nn::Module {
 public:
  NestedModel()
      : param_(register_parameter("param", torch::empty({3, 2, 21}))),
        l1(register_module("l1", Linear(5, 20))),
        t(register_module("test", std::make_shared<TestModel>())) {}

  torch::Tensor param_;
  Linear l1;
  std::shared_ptr<TestModel> t;
};

struct ModulesTest : torch::test::SeedingFixture {};

TEST_F(ModulesTest, Conv1d) {
  Conv1d model(Conv1dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(s.ndimension(), 0);
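  // With kernel size 3 and stride 2 on a length-5 input, the output length is
  // (5 - 3) / 2 + 1 = 2, so batch, channel, and length dimensions are all 2.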
  for (auto i = 0; i < 3; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3);
}

TEST_F(ModulesTest, Conv2dEven) {
  Conv2d model(Conv2dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(s.ndimension(), 0);
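  // Each 5x5 spatial plane shrinks to (5 - 3) / 2 + 1 = 2 in both directions,
  // so all four output dimensions are 2.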
  for (auto i = 0; i < 4; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3);
}

TEST_F(ModulesTest, Conv2dUneven) {
  Conv2d model(Conv2dOptions(3, 2, {3, 2}).stride({2, 2}));
  auto x = torch::randn({2, 3, 5, 4}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 4);
  ASSERT_EQ(s.ndimension(), 0);
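  // Height: (5 - 3) / 2 + 1 = 2; width: (4 - 2) / 2 + 1 = 2. Together with
  // batch size 2 and 2 output channels, every dimension of y is 2.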
  for (auto i = 0; i < 4; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 2);
}

TEST_F(ModulesTest, Conv3d) {
  Conv3d model(Conv3dOptions(3, 2, 3).stride(2));
  auto x = torch::randn({2, 3, 5, 5, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 5);
  ASSERT_EQ(s.ndimension(), 0);
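  // Each 5x5x5 volume shrinks to (5 - 3) / 2 + 1 = 2 per spatial dimension,
  // so all five output dimensions are 2.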
  for (auto i = 0; i < 5; i++) {
    ASSERT_EQ(y.size(i), 2);
  }

  ASSERT_EQ(model->weight.grad().numel(), 3 * 2 * 3 * 3 * 3);
}

TEST_F(ModulesTest, Linear) {
  Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

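  // Linear stores its weight as {out_features, in_features} = {2, 5}, and the
  // gradient has the same shape.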
  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}

TEST_F(ModulesTest, SimpleContainer) {
  auto model = std::make_shared<SimpleContainer>();
  auto l1 = model->add(Linear(10, 3), "l1");
  auto l2 = model->add(Linear(3, 5), "l2");
  auto l3 = model->add(Linear(5, 100), "l3");

  auto x = torch::randn({1000, 10}, torch::requires_grad());
  x = l1(x).clamp_min(0);
  x = l2(x).clamp_min(0);
  x = l3(x).clamp_min(0);

  // x is not a scalar, so backward() needs an explicit gradient argument.
  x.backward(torch::ones_like(x));
  ASSERT_EQ(x.ndimension(), 2);
  ASSERT_EQ(x.size(0), 1000);
  ASSERT_EQ(x.size(1), 100);
  ASSERT_EQ(x.min().item<float>(), 0);
}

TEST_F(ModulesTest, EmbeddingBasic) {
  const int64_t dict_size = 10;
  Embedding model(dict_size, 2);
  ASSERT_TRUE(model->named_parameters().contains("weight"));
  ASSERT_EQ(model->weight.ndimension(), 2);
  ASSERT_EQ(model->weight.size(0), dict_size);
  ASSERT_EQ(model->weight.size(1), 2);

  // Gradients flow only into the embedding weights; the integer indices used
  // as input cannot receive gradients.
  auto x = torch::full({10}, dict_size - 1, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

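  // The gradient is dense here, so it matches the weight's shape: dict_size
  // rows of dimension 2.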
  ASSERT_EQ(model->weight.grad().numel(), 2 * dict_size);
}

TEST_F(ModulesTest, EmbeddingList) {
  Embedding model(6, 4);
  auto x = torch::full({2, 3}, 5, torch::kInt64);
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 3);
  ASSERT_EQ(y.size(0), 2);
  ASSERT_EQ(y.size(1), 3);
  ASSERT_EQ(y.size(2), 4);
}

TEST_F(ModulesTest, Dropout) {
  Dropout dropout(0.5);
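  // Dropout zeroes each element with probability p = 0.5 and scales survivors
  // by 1 / (1 - p) = 2, so the expected sum over 100 ones stays at 100.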
  torch::Tensor x = torch::ones(100, torch::requires_grad());
  torch::Tensor y = dropout(x);

  // y is not a scalar, so backward() needs an explicit gradient argument.
  y.backward(torch::ones_like(y));
  ASSERT_EQ(y.ndimension(), 1);
  ASSERT_EQ(y.size(0), 100);
  // Stochastic bounds: the sum has mean 100 and standard deviation 10, so
  // (70, 130) is roughly a 3-sigma window.
  ASSERT_LT(y.sum().item<float>(), 130);
  ASSERT_GT(y.sum().item<float>(), 70);

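  // In evaluation mode dropout is the identity, so the sum is exactly 100.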
  dropout->eval();
  y = dropout(x);
  ASSERT_EQ(y.sum().item<float>(), 100);
}

TEST_F(ModulesTest, Parameters) {
  auto model = std::make_shared<NestedModel>();
  auto parameters = model->named_parameters();
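  // Nested parameters are exposed under dot-separated hierarchical names,
  // e.g. "test.l1.weight" for the l1 weight of the submodule registered as
  // "test".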
  ASSERT_EQ(parameters["param"].size(0), 3);
  ASSERT_EQ(parameters["param"].size(1), 2);
  ASSERT_EQ(parameters["param"].size(2), 21);
  ASSERT_EQ(parameters["l1.bias"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(0), 20);
  ASSERT_EQ(parameters["l1.weight"].size(1), 5);
  ASSERT_EQ(parameters["test.l1.bias"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(0), 3);
  ASSERT_EQ(parameters["test.l1.weight"].size(1), 10);
  ASSERT_EQ(parameters["test.l2.bias"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(0), 5);
  ASSERT_EQ(parameters["test.l2.weight"].size(1), 3);
  ASSERT_EQ(parameters["test.l3.bias"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(0), 100);
  ASSERT_EQ(parameters["test.l3.weight"].size(1), 5);
}

TEST_F(ModulesTest, FunctionalCallsSuppliedFunction) {
  bool was_called = false;
  auto functional = Functional([&was_called](torch::Tensor input) {
    was_called = true;
    return input;
  });
  auto output = functional->forward(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));

  was_called = false;
  // Use the call operator overload here.
  output = functional(torch::ones(5, torch::requires_grad()));
  ASSERT_TRUE(was_called);
  ASSERT_TRUE(output.equal(torch::ones(5, torch::requires_grad())));
}

TEST_F(ModulesTest, FunctionalWithTorchFunction) {
  auto functional = Functional(torch::relu);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 1);
  ASSERT_EQ(functional(torch::ones({}) * -1).item<float>(), 0);
}

TEST_F(ModulesTest, FunctionalArgumentBinding) {
  auto functional =
      Functional(torch::elu, /*alpha=*/1, /*scale=*/0, /*input_scale=*/1);
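  // The extra arguments are bound after the input. torch::elu multiplies its
  // result by `scale`, so binding scale = 0 forces the output to zero
  // regardless of the input.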
  ASSERT_EQ(functional(torch::ones({})).item<float>(), 0);
}

TEST_F(ModulesTest, BatchNormStateful) {
  BatchNorm bn(5);

  // Is stateful by default.
  ASSERT_TRUE(bn->options.stateful());

  ASSERT_TRUE(bn->running_mean.defined());
  ASSERT_EQ(bn->running_mean.dim(), 1);
  ASSERT_EQ(bn->running_mean.size(0), 5);

  ASSERT_TRUE(bn->running_var.defined());
  ASSERT_EQ(bn->running_var.dim(), 1);
  ASSERT_EQ(bn->running_var.size(0), 5);

  // Is affine by default.
  ASSERT_TRUE(bn->options.affine());

  ASSERT_TRUE(bn->weight.defined());
  ASSERT_EQ(bn->weight.dim(), 1);
  ASSERT_EQ(bn->weight.size(0), 5);

  ASSERT_TRUE(bn->bias.defined());
  ASSERT_EQ(bn->bias.dim(), 1);
  ASSERT_EQ(bn->bias.size(0), 5);
}

TEST_F(ModulesTest, BatchNormStateless) {
  BatchNorm bn(BatchNormOptions(5).stateful(false).affine(false));

  ASSERT_FALSE(bn->running_mean.defined());
  ASSERT_FALSE(bn->running_var.defined());
  ASSERT_FALSE(bn->weight.defined());
  ASSERT_FALSE(bn->bias.defined());

  ASSERT_THROWS_WITH(
      bn(torch::ones({2, 5})),
      "Calling BatchNorm::forward is only permitted "
      "when the 'stateful' option is true (was false). "
      "Use BatchNorm::pure_forward instead.");
}

TEST_F(ModulesTest, BatchNormPureForward) {
  BatchNorm bn(BatchNormOptions(5).affine(false));
  bn->eval();

  // The module is in eval mode, so its running statistics are left untouched.
  // We want to make sure `pure_forward` uses the supplied values even though
  // the module is stateful.
  auto input = torch::randn({2, 5});
  auto mean = torch::randn(5);
  auto variance = torch::rand(5);
  auto output = bn->pure_forward(input, mean, variance);
  auto expected = (input - mean) / torch::sqrt(variance + bn->options.eps());
  ASSERT_TRUE(output.allclose(expected));
}

TEST_F(ModulesTest, Linear_CUDA) {
  Linear model(5, 2);
  model->to(torch::kCUDA);
  auto x =
      torch::randn({10, 5}, torch::device(torch::kCUDA).requires_grad(true));
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}

TEST_F(ModulesTest, Linear2_CUDA) {
  Linear model(5, 2);
  model->to(torch::kCUDA);
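  // Round-trip the parameters back to the CPU; the module should remain fully
  // usable with CPU inputs afterwards.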
  model->to(torch::kCPU);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model(x);
  torch::Tensor s = y.sum();

  s.backward();
  ASSERT_EQ(y.ndimension(), 2);
  ASSERT_EQ(s.ndimension(), 0);
  ASSERT_EQ(y.size(0), 10);
  ASSERT_EQ(y.size(1), 2);

  ASSERT_EQ(model->weight.grad().numel(), 2 * 5);
}

TEST_F(ModulesTest, PrettyPrintLinear) {
  ASSERT_EQ(
      c10::str(Linear(3, 4)), "torch::nn::Linear(in=3, out=4, with_bias=true)");
}

TEST_F(ModulesTest, PrettyPrintConv) {
  ASSERT_EQ(
      c10::str(Conv1d(3, 4, 5)),
      "torch::nn::Conv1d(input_channels=3, output_channels=4, kernel_size=5, stride=1)");
  ASSERT_EQ(
      c10::str(Conv2d(3, 4, 5)),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[1, 1])");
  ASSERT_EQ(
      c10::str(Conv2d(Conv2dOptions(3, 4, 5).stride(2))),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 5], stride=[2, 2])");

  const auto options =
      Conv2dOptions(3, 4, torch::IntArrayRef{5, 6}).stride({1, 2});
  ASSERT_EQ(
      c10::str(Conv2d(options)),
      "torch::nn::Conv2d(input_channels=3, output_channels=4, kernel_size=[5, 6], stride=[1, 2])");
}


TEST_F(ModulesTest, PrettyPrintDropout) {
  ASSERT_EQ(c10::str(Dropout(0.5)), "torch::nn::Dropout(rate=0.5)");
  ASSERT_EQ(
      c10::str(FeatureDropout(0.5)), "torch::nn::FeatureDropout(rate=0.5)");
}

TEST_F(ModulesTest, PrettyPrintFunctional) {
  ASSERT_EQ(c10::str(Functional(torch::relu)), "torch::nn::Functional()");
}

TEST_F(ModulesTest, PrettyPrintBatchNorm) {
  ASSERT_EQ(
      c10::str(BatchNorm(
          BatchNormOptions(4).eps(0.5).momentum(0.1).affine(false).stateful(
              true))),
      "torch::nn::BatchNorm(features=4, eps=0.5, momentum=0.1, affine=false, stateful=true)");
}

TEST_F(ModulesTest, PrettyPrintEmbedding) {
  ASSERT_EQ(
      c10::str(Embedding(10, 2)),
      "torch::nn::Embedding(count=10, dimension=2)");
}

TEST_F(ModulesTest, PrettyPrintNestedModel) {
  struct InnerTestModule : torch::nn::Module {
    InnerTestModule()
        : torch::nn::Module("InnerTestModule"),
          fc(register_module("fc", torch::nn::Linear(3, 4))),
          table(register_module("table", torch::nn::Embedding(10, 2))) {}

    torch::nn::Linear fc;
    torch::nn::Embedding table;
  };

  struct TestModule : torch::nn::Module {
    TestModule()
        : torch::nn::Module("TestModule"),
          fc(register_module("fc", torch::nn::Linear(4, 5))),
          table(register_module("table", torch::nn::Embedding(10, 2))),
          inner(register_module("inner", std::make_shared<InnerTestModule>())) {
    }

    torch::nn::Linear fc;
    torch::nn::Embedding table;
    std::shared_ptr<InnerTestModule> inner;
  };

  ASSERT_EQ(
      c10::str(TestModule{}),
      "TestModule(\n"
      "  (fc): torch::nn::Linear(in=4, out=5, with_bias=true)\n"
      "  (table): torch::nn::Embedding(count=10, dimension=2)\n"
      "  (inner): InnerTestModule(\n"
      "    (fc): torch::nn::Linear(in=3, out=4, with_bias=true)\n"
      "    (table): torch::nn::Embedding(count=10, dimension=2)\n"
      "  )\n"
      ")");
}