#include <c10/core/TensorOptions.h>
#include <test/cpp/jit/test_base.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>

#include <cstdlib> // std::getenv
#include <cstring> // std::strcmp

// Tests go in torch::jit
namespace torch {
namespace jit {

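// Saves a module that calls torch.upsample_nearest2d for mobile, reloads it
// with the lite interpreter, and checks that the result matches the full-JIT
// reference.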
void testLiteInterpreterUpsampleNearest2d() {
  Module m("m");
  m.define(R"(
    def forward(self, input: Tensor, scale: float):
      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
  )");

  std::vector<IValue> inputs;
  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
  inputs.emplace_back(at::Scalar(2.0));
  auto ref = m.forward(inputs);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  res = bc.forward(inputs);

  auto resd = res.toTensor();
  auto refd = ref.toTensor();
  ASSERT_TRUE(resd.equal(refd));
}

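// Round-trips a module with a registered parameter through the mobile
// serializer and checks that add_it returns the same scalar as the full-JIT
// run. The bytecode is executed several times to make sure repeated runs are
// stable.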
void testLiteInterpreterAdd() {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  // TODO: support default param val, which was pushed in
  // function schema's checkAndNormalizeInputs()
  // m.define(R"(
  //   def add_it(self, x, b : int = 4):
  //     return self.foo + x + b
  // )");
  m.define(R"(
    def add_it(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("add_it", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("add_it", bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

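// Runs torch._convolution from mobile bytecode and compares one output
// element against the full-JIT reference. Skipped under ThreadSanitizer
// (PYTORCH_TEST_WITH_TSAN=1).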
void testLiteInterpreterConv() {
  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
  if (s && std::strcmp(s, "1") == 0)
    return;

  std::vector<torch::jit::IValue> inputs;

  Module m("m");
  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
  m.register_parameter("bias", torch::ones({20}), false);
  m.define(R"(
    def forward(self, input):
      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True)
  )");

  inputs.push_back(torch::ones({1, 1, 28, 28}));

  auto outputref = m.forward(inputs).toTensor();

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    res = bc.run_method("forward", inputs);
  }
  auto output = res.toTensor();
  AT_ASSERT(outputref.dim() == output.dim());
  AT_ASSERT(
      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}

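// Checks that method-to-method calls (foo3 -> foo2 -> foo1) still resolve
// correctly when running from mobile bytecode.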
void testLiteInterpreterInline() {
  Module m("m");
  m.define(R"JIT(
    def foo1(self, x):
      return x + 1

    def foo2(self, x):
      return self.foo1(x) + 2

    def foo3(self, x):
      return self.foo2(x) + 3
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("foo3", inputs);
  AT_ASSERT(output.toTensor().item<float>() == 7.0);
}

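// Checks that a tuple return value survives the mobile round trip.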
void testLiteInterpreterTuple() {
  Module m("m");
  m.define(R"JIT(
    def foo(self, x):
      return (1, 2, x + 3)

    def forward(self, x):
      tuple = self.foo(x)
      return tuple
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("forward", inputs);
  AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
}

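// Checks that a dict return value survives the mobile round trip.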
void testLiteInterpreterDict() {
  Module m("m");
  m.define(R"JIT(
    def foo(self, x):
      return {"result": x + 1}

    def forward(self, x):
      d = self.foo(x)
      return d
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("forward", inputs);
  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}

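// Placeholder for a prim overload test (list append); the body below is
// temporarily disabled.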
void testLiteInterpreterPrimOverload() {
  /*
  // temporarily disabled
  script::Module m("m");
  m.define(R"JIT(
    def forward(self, x):
      result = [1, 2]
      result.append(3)
      return result
  )JIT");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<torch::jit::IValue> inputs({torch::ones({})});
  auto output = bc.run_method("forward", inputs);
  AT_ASSERT(output.toIntList()[2] == 3);
  */
}

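// Checks a prim operator (int(x)) against the full-JIT reference.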
void testLiteInterpreterPrim() {
  Module m("m");
  m.define(R"JIT(
    def forward(self, x):
      return int(x)
  )JIT");

  std::vector<IValue> inputs;
  auto minput = 3.5 * torch::ones({});
  inputs.emplace_back(minput);
  auto ref = m.run_method("forward", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("forward", bcinputs);
  }

  auto resi = res.toInt();
  auto refi = ref.toInt();
  AT_ASSERT(resi == refi);
}

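// A module saved with the regular JIT serializer (save) contains no mobile
// bytecode, so _load_for_mobile should fail with "file not found".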
void testLiteInterpreterLoadOrigJit() {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");
  std::stringstream ss;
  m.save(ss);
  ASSERT_THROWS_WITH(_load_for_mobile(ss), "file not found");
}

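// Calling a method name that was never defined should raise "is not defined".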
void testLiteInterpreterWrongMethodName() {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def add(self, x):
      b = 4
      return self.foo + x + b
  )");
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);
  ASSERT_THROWS_WITH(bc.run_method("forward", inputs), "is not defined");
}

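// Trains a one-parameter model with SGD through both the full-JIT module and
// the lite-interpreter module, then checks that the trained parameter values
// match.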
void testLiteInterpreterParams() {
  Module m("m");
  m.register_parameter("foo", torch::ones({1}, at::requires_grad()), false);
  m.define(R"(
    def forward(self, x):
      b = 1.0
      return self.foo * x + b
  )");
  double learning_rate = 0.1, momentum = 0.1;
  int n_epochs = 10;
  // init: y = x + 1;
  // target: y = 2 x + 1
  std::vector<std::pair<Tensor, Tensor>> train_data{
      {1 * torch::ones({1}), 3 * torch::ones({1})},
  };
  // Reference: full JIT
  std::stringstream ms;
  m.save(ms);
  auto mm = load(ms);
  // mm.train();
  std::vector<::at::Tensor> parameters;
  for (auto parameter : mm.parameters()) {
    parameters.emplace_back(parameter);
  }
  ::torch::optim::SGD optimizer(
      parameters, ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoch = 0; epoch < n_epochs; ++epoch) {
    for (auto& data : train_data) {
      auto source = data.first, targets = data.second;
      optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = mm.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      optimizer.step();
    }
  }
  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  std::vector<::at::Tensor> bc_parameters = bc.parameters();
  ::torch::optim::SGD bc_optimizer(
      bc_parameters,
      ::torch::optim::SGDOptions(learning_rate).momentum(momentum));
  for (int epoch = 0; epoch < n_epochs; ++epoch) {
    for (auto& data : train_data) {
      auto source = data.first, targets = data.second;
      bc_optimizer.zero_grad();
      std::vector<IValue> train_inputs{source};
      auto output = bc.forward(train_inputs).toTensor();
      auto loss = ::torch::l1_loss(output, targets);
      loss.backward();
      bc_optimizer.step();
    }
  }
  AT_ASSERT(parameters[0].item<float>() == bc_parameters[0].item<float>());
}

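// Checks that user-defined __getstate__/__setstate__ are honored by the
// mobile serializer: both the full-JIT and mobile reloads should see foo
// doubled, so their forward results must match.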
void testLiteInterpreterSetState() {
  Module m("m");
  m.register_parameter("foo", torch::ones({}), false);
  m.define(R"(
    def __getstate__(self):
      return self.foo + self.foo
    def __setstate__(self, a):
      self.foo = a
    def forward(self, x):
      b = 4
      return self.foo + x + b
  )");

  std::vector<IValue> inputs;
  auto minput = 5 * torch::ones({});
  inputs.emplace_back(minput);

  std::stringstream ms;
  m.save(ms);
  auto loaded_m = load(ms);
  auto ref = loaded_m.run_method("forward", minput);

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  IValue res;
  for (int i = 0; i < 3; ++i) {
    auto bcinputs = inputs;
    res = bc.run_method("forward", bcinputs);
  }

  auto resd = res.toTensor().item<float>();
  auto refd = ref.toTensor().item<float>();
  AT_ASSERT(resd == refd);
}

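// Custom class used to exercise torch::jit::class_ bindings from the lite
// interpreter.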
class TorchBindLiteInterpreterTestStruct
    : public torch::jit::CustomClassHolder {
 public:
  std::string get(at::Tensor t) {
    std::stringstream ss;
    ss << "Hello! Your tensor has ";
    ss << t.numel();
    ss << " elements!";
    return ss.str();
  }
};

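// Checks that a method on a bound custom class can be called from mobile
// bytecode.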
void testLiteInterpreterBuiltinFunction() {
  script::Module m("m");
  auto custom_class_obj =
      make_custom_class<TorchBindLiteInterpreterTestStruct>();
  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
  m.define(R"(
    def forward(self, x) -> str:
      return self.my_obj.get(x)
  )");

  std::stringstream ss;
  m._save_for_mobile(ss);
  mobile::Module bc = _load_for_mobile(ss);
  auto res =
      bc.run_method("forward", std::vector<IValue>{torch::zeros({3, 4})});
  auto str = res.toStringRef();
  std::string expected = "Hello! Your tensor has 12 elements!";
  AT_ASSERT(str == expected);
}

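// Registers the custom class above. def_pickle makes it serializable; the
// class carries no state, so __getstate__ returns a dummy value and
// __setstate__ constructs a fresh instance.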
namespace {
static auto reg =
    torch::jit::class_<TorchBindLiteInterpreterTestStruct>(
        "_TorchScriptTesting",
        "_LiteInterpreterTest")
        .def("get", &TorchBindLiteInterpreterTestStruct::get)
        .def_pickle(
            // __getstate__
            [](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>&
                   self) -> int64_t { return 0; },
            // __setstate__
            [](int64_t state) {
              return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>();
            });

} // namespace

} // namespace jit
} // namespace torch