| #include <gtest/gtest.h> |
| |
| #include <c10/core/TensorOptions.h> |
| #include <torch/csrc/autograd/generated/variable_factories.h> |
| #include <torch/csrc/jit/api/module.h> |
| #include <torch/csrc/jit/mobile/export_data.h> |
| #include <torch/csrc/jit/mobile/import.h> |
| #include <torch/csrc/jit/mobile/import_data.h> |
| #include <torch/csrc/jit/mobile/module.h> |
| #include <torch/csrc/jit/mobile/optim/sgd.h> |
| #include <torch/csrc/jit/mobile/sequential.h> |
| #include <torch/csrc/jit/serialization/import.h> |
| #include <torch/data/dataloader.h> |
| #include <torch/torch.h> |
| |
| // Tests go in torch::jit |
| namespace torch { |
| namespace jit { |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| TEST(LiteTrainerTest, Params) { |
| Module m("m"); |
| m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false); |
| m.define(R"( |
| def forward(self, x): |
| b = 1.0 |
| return self.foo * x + b |
| )"); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| double learning_rate = 0.1, momentum = 0.1; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| int n_epoc = 10; |
| // init: y = x + 1; |
| // target: y = 2 x + 1 |
| std::vector<std::pair<Tensor, Tensor>> trainData{ |
| {1 * torch::ones({1}), 3 * torch::ones({1})}, |
| }; |
| // Reference: Full jit |
| std::stringstream ms; |
| m.save(ms); |
| auto mm = load(ms); |
| // mm.train(); |
| std::vector<::at::Tensor> parameters; |
| for (auto parameter : mm.parameters()) { |
| parameters.emplace_back(parameter); |
| } |
| ::torch::optim::SGD optimizer( |
| parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
| for (int epoc = 0; epoc < n_epoc; ++epoc) { |
| for (auto& data : trainData) { |
| auto source = data.first, targets = data.second; |
| optimizer.zero_grad(); |
| std::vector<IValue> train_inputs{source}; |
| auto output = mm.forward(train_inputs).toTensor(); |
| auto loss = ::torch::l1_loss(output, targets); |
| loss.backward(); |
| optimizer.step(); |
| } |
| } |
| std::stringstream ss; |
| m._save_for_mobile(ss); |
| mobile::Module bc = _load_for_mobile(ss); |
| std::vector<::at::Tensor> bc_parameters = bc.parameters(); |
| ::torch::optim::SGD bc_optimizer( |
| bc_parameters, |
| ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
| for (int epoc = 0; epoc < n_epoc; ++epoc) { |
| for (auto& data : trainData) { |
| auto source = data.first, targets = data.second; |
| bc_optimizer.zero_grad(); |
| std::vector<IValue> train_inputs{source}; |
| auto output = bc.forward(train_inputs).toTensor(); |
| auto loss = ::torch::l1_loss(output, targets); |
| loss.backward(); |
| bc_optimizer.step(); |
| } |
| } |
| AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>()); |
| } |
| |
| // TODO Renable these tests after parameters are correctly loaded on mobile |
| /* |
| TEST(MobileTest, NamedParameters) { |
| Module m("m"); |
| m.register_parameter("foo", torch::ones({}), false); |
| m.define(R"( |
| def add_it(self, x): |
| b = 4 |
| return self.foo + x + b |
| )"); |
| Module child("m2"); |
| child.register_parameter("foo", 4 * torch::ones({}), false); |
| child.register_parameter("bar", 4 * torch::ones({}), false); |
| m.register_module("child1", child); |
| m.register_module("child2", child.clone()); |
| std::stringstream ss; |
| m._save_for_mobile(ss); |
| mobile::Module bc = _load_for_mobile(ss); |
| |
| auto full_params = m.named_parameters(); |
| auto mobile_params = bc.named_parameters(); |
| AT_ASSERT(full_params.size() == mobile_params.size()); |
| for (const auto& e : full_params) { |
| AT_ASSERT(e.value.item().toInt() == |
| mobile_params[e.name].item().toInt()); |
| } |
| } |
| |
| TEST(MobileTest, SaveLoadData) { |
| Module m("m"); |
| m.register_parameter("foo", torch::ones({}), false); |
| m.define(R"( |
| def add_it(self, x): |
| b = 4 |
| return self.foo + x + b |
| )"); |
| Module child("m2"); |
| child.register_parameter("foo", 4 * torch::ones({}), false); |
| child.register_parameter("bar", 3 * torch::ones({}), false); |
| m.register_module("child1", child); |
| m.register_module("child2", child.clone()); |
| auto full_params = m.named_parameters(); |
| |
| std::stringstream ss; |
| std::stringstream ss_data; |
| m._save_for_mobile(ss); |
| mobile::Module bc = _load_for_mobile(ss); |
| |
| mobile::_save_data(bc, ss_data); |
| auto mobile_params = mobile::_load_data(ss_data).named_parameters(); |
| AT_ASSERT(full_params.size() == mobile_params.size()); |
| for (const auto& e : full_params) { |
| AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>()); |
| } |
| } |
| |
| TEST(MobileTest, SaveLoadParameters) { |
| Module m("m"); |
| m.register_parameter("foo", torch::ones({}), false); |
| m.define(R"( |
| def add_it(self, x): |
| b = 4 |
| return self.foo + x + b |
| )"); |
| Module child("m2"); |
| child.register_parameter("foo", 4 * torch::ones({}), false); |
| child.register_parameter("bar", 3 * torch::ones({}), false); |
| m.register_module("child1", child); |
| m.register_module("child2", child.clone()); |
| auto full_params = m.named_parameters(); |
| std::stringstream ss; |
| std::stringstream ss_data; |
| m._save_for_mobile(ss); |
| |
| // load mobile module, save mobile named parameters |
| mobile::Module bc = _load_for_mobile(ss); |
| _save_parameters(bc.named_parameters(), ss_data); |
| |
| // load back the named parameters, compare to full-jit Module's |
| auto mobile_params = _load_parameters(ss_data); |
| AT_ASSERT(full_params.size() == mobile_params.size()); |
| for (const auto& e : full_params) { |
| AT_ASSERT(e.value.item<int>() == mobile_params[e.name].item<int>()); |
| } |
| } |
| */ |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| TEST(MobileTest, SaveLoadParametersEmpty) { |
| Module m("m"); |
| m.define(R"( |
| def add_it(self, x): |
| b = 4 |
| return x + b |
| )"); |
| Module child("m2"); |
| m.register_module("child1", child); |
| m.register_module("child2", child.clone()); |
| std::stringstream ss; |
| std::stringstream ss_data; |
| m._save_for_mobile(ss); |
| |
| // load mobile module, save mobile named parameters |
| mobile::Module bc = _load_for_mobile(ss); |
| _save_parameters(bc.named_parameters(), ss_data); |
| |
| // load back the named parameters, test is empty |
| auto mobile_params = _load_parameters(ss_data); |
| AT_ASSERT(mobile_params.size() == 0); |
| } |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| TEST(LiteTrainerTest, SGD) { |
| Module m("m"); |
| m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false); |
| m.define(R"( |
| def forward(self, x): |
| b = 1.0 |
| return self.foo * x + b |
| )"); |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| double learning_rate = 0.1, momentum = 0.1; |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| int n_epoc = 10; |
| // init: y = x + 1; |
| // target: y = 2 x + 1 |
| std::vector<std::pair<Tensor, Tensor>> trainData{ |
| {1 * torch::ones({1}), 3 * torch::ones({1})}, |
| }; |
| // Reference: Full jit and torch::optim::SGD |
| std::stringstream ms; |
| m.save(ms); |
| auto mm = load(ms); |
| std::vector<::at::Tensor> parameters; |
| for (auto parameter : mm.parameters()) { |
| parameters.emplace_back(parameter); |
| } |
| ::torch::optim::SGD optimizer( |
| parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum)); |
| for (int epoc = 0; epoc < n_epoc; ++epoc) { |
| for (auto& data : trainData) { |
| auto source = data.first, targets = data.second; |
| optimizer.zero_grad(); |
| std::vector<IValue> train_inputs{source}; |
| auto output = mm.forward(train_inputs).toTensor(); |
| auto loss = ::torch::l1_loss(output, targets); |
| loss.backward(); |
| optimizer.step(); |
| } |
| } |
| // Test: lite interpreter and torch::jit::mobile::SGD |
| std::stringstream ss; |
| m._save_for_mobile(ss); |
| mobile::Module bc = _load_for_mobile(ss); |
| std::vector<::at::Tensor> bc_parameters = bc.parameters(); |
| ::torch::jit::mobile::SGD bc_optimizer( |
| bc_parameters, |
| ::torch::jit::mobile::SGDOptions(learning_rate).momentum(momentum)); |
| for (int epoc = 0; epoc < n_epoc; ++epoc) { |
| for (auto& data : trainData) { |
| auto source = data.first, targets = data.second; |
| bc_optimizer.zero_grad(); |
| std::vector<IValue> train_inputs{source}; |
| auto output = bc.forward(train_inputs).toTensor(); |
| auto loss = ::torch::l1_loss(output, targets); |
| loss.backward(); |
| bc_optimizer.step(); |
| } |
| } |
| AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>()); |
| } |
| |
| namespace { |
| struct DummyDataset : torch::data::datasets::Dataset<DummyDataset, int> { |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| explicit DummyDataset(size_t size = 100) : size_(size) {} |
| |
| int get(size_t index) override { |
| // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions) |
| return 1 + index; |
| } |
| torch::optional<size_t> size() const override { |
| return size_; |
| } |
| |
| size_t size_; |
| }; |
| } // namespace |
| |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) |
| TEST(LiteTrainerTest, SequentialSampler) { |
| // test that sampler can be used with dataloader |
| const int kBatchSize = 10; |
| auto data_loader = |
| torch::data::make_data_loader<torch::data::samplers::SequentialSampler>( |
| // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers) |
| DummyDataset(25), |
| kBatchSize); |
| int i = 1; |
| for (const auto& batch : *data_loader) { |
| for (const auto& example : batch) { |
| AT_ASSERT(i == example); |
| i++; |
| } |
| } |
| } |
| |
| } // namespace jit |
| } // namespace torch |