blob: ebe7eb150c6d50c9bbb5af3bfebbaff3947a18ec [file] [log] [blame]
#include <test/cpp/jit/test_utils.h>
#include <gtest/gtest.h>
#include <c10/core/TensorOptions.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/frontend/resolver.h>
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/serialization/export.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/custom_class.h>
#include <torch/torch.h>
#include <unordered_set>
// Tests go in torch::jit
namespace torch {
namespace jit {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, UpsampleNearest2d) {
Module m("m");
m.define(R"(
def forward(self, input: Tensor, scale:float):
return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
)");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
inputs.emplace_back(torch::rand({1, 3, 128, 128}));
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
inputs.emplace_back(at::Scalar(2.0));
auto ref = m.forward(inputs);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
res = bc.forward(inputs);
auto resd = res.toTensor();
auto refd = ref.toTensor();
ASSERT_TRUE(resd.equal(refd));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, CheckAttrAccess) {
Module m("m");
m.register_attribute("mobile_optimized", BoolType::get(), true);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
AT_ASSERT(mobile_optimized);
m.setattr("mobile_optimized", false);
ss = std::stringstream();
m._save_for_mobile(ss);
bc = _load_for_mobile(ss);
mobile_optimized = bc.attr("mobile_optimized", false).toBool();
AT_ASSERT(!mobile_optimized);
}
TEST(LiteInterpreterTest, MethodInvocation) { // NOLINT (use =delete in gtest)
const std::vector<std::string> test_programs{
// test invoking a method with default parameter
R"(
def test_func(self, x, b : int = 4):
return self.foo + x + b
)",
// inner method call with default parameter (gets inlined)
R"(
def add_with_default_arg(self, x, b : int = 4):
return self.foo + x + b
def test_func(self, x):
return self.add_with_default_arg(x) # invoke method w/ default arg
)",
// simple method call
R"(
def test_func(self, x):
b = 4
return self.foo + x + b
)",
};
for (const auto& test_program : test_programs) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(test_program);
const int fortyTwo = 42; // (keep linter happy)
auto minput = fortyTwo * torch::ones({});
auto ref = m.run_method("test_func", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
const auto& test_func = bc.get_method("test_func");
IValue res;
for (int i = 0; i < 3; ++i) {
res = test_func({minput});
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Conv) {
auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
if (s && strcmp(s, "1") == 0)
return;
std::vector<torch::jit::IValue> inputs;
Module m("m");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
m.register_parameter("bias", torch::ones({20}), false);
m.define(R"(
def forward(self, input):
return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
)");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
inputs.push_back(torch::ones({1, 1, 28, 28}));
auto outputref = m.forward(inputs).toTensor();
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
res = bc.get_method("forward")(inputs);
}
auto output = res.toTensor();
AT_ASSERT(outputref.dim() == output.dim());
AT_ASSERT(
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Inline) {
Module m("m");
m.define(R"JIT(
def foo1(self, x):
return x + 1
def foo2(self, x):
return self.foo1(x) + 2
def foo3(self, x):
return self.foo2(x) + 3
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.get_method("foo3")(inputs);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
AT_ASSERT(output.toTensor().item<float>() == 7.0);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Tuple) {
Module m("m");
m.define(R"JIT(
def foo(self, x):
return (1, 2, x + 3)
def forward(self, x):
tuple = self.foo(x)
return tuple
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.get_method("forward")(inputs);
AT_ASSERT(output.toTuple()->elements()[1].toInt() == 2);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Dict) {
Module m("m");
m.define(R"JIT(
def foo(self, x):
return {"result": x + 1}
def forward(self, x):
d = self.foo(x)
return d
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.get_method("forward")(inputs);
AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, PrimOverload) {
/*
// temporarily disabled
script::Module m("m");
m.define(R"JIT(
def forward(self, x):
result = [1, 2]
result.append(3)
return result
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<torch::jit::IValue> inputs({torch::ones({})});
auto output = bc.get_method("forward")(inputs);
AT_ASSERT(output.toIntList()[2] == 3);
*/
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Prim) {
Module m("m");
m.define(R"JIT(
def forward(self, x):
return int(x)
)JIT");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto minput = 3.5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("forward", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto bcinputs = inputs;
res = bc.get_method("forward")(bcinputs);
}
auto resi = res.toInt();
auto refi = ref.toInt();
AT_ASSERT(resi == refi);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, PrimScalar) {
Module m("m");
m.define(R"JIT(
def forward(self, x):
return int(x.item())
)JIT");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto minput = 3.5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.run_method("forward", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto bcinputs = inputs;
res = bc.get_method("forward")(bcinputs);
}
auto resi = res.toInt();
auto refi = ref.toInt();
AT_ASSERT(resi == refi);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, LoadOrigJit) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def forward(self, x):
b = 4
return self.foo + x + b
)");
std::stringstream ss;
m.save(ss);
ASSERT_THROWS_WITH_MESSAGE(_load_for_mobile(ss), "file not found");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, WrongMethodName) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def add(self, x):
b = 4
return self.foo + x + b
)");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
ASSERT_THROWS_WITH_MESSAGE(
bc.get_method("forward")(inputs), "is not defined");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, SetState) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def __getstate__(self):
return self.foo + self.foo
def __setstate__(self, a):
self.foo = a
def forward(self, x):
b = 4
return self.foo + x + b
)");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
std::stringstream ms;
m.save(ms);
auto loaded_m = load(ms);
auto ref = loaded_m.run_method("forward", minput);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto bcinputs = inputs;
res = bc.get_method("forward")(bcinputs);
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
class TorchBindLiteInterpreterTestStruct
: public torch::jit::CustomClassHolder {
public:
std::string get(at::Tensor t) {
std::stringstream ss;
ss << "Hello! Your tensor has ";
ss << t.numel();
ss << " elements!";
return ss.str();
}
};
namespace {
struct ClassNamespaceValue : public SugaredValue {
explicit ClassNamespaceValue(c10::QualifiedName name)
: basename_(std::move(name)) {}
std::shared_ptr<SugaredValue> attr(
const SourceRange& loc,
Function& m,
const std::string& name) override {
const auto fullName = c10::QualifiedName(basename_, name);
// Check to see if it is a custom class.
if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
return std::make_shared<ClassValue>(custom_class);
}
// If it's not a custom class, assume it's another namespace
// NOLINTNEXTLINE(performance-move-const-arg)
return std::make_shared<ClassNamespaceValue>(std::move(fullName));
}
std::string kind() const override {
return "Class Namespace";
}
private:
c10::QualifiedName basename_;
};
struct TestModuleResolver : public Resolver {
std::shared_ptr<SugaredValue> resolveValue(
const std::string& name,
Function& m,
const SourceRange& loc) override {
if (name == "torch") {
return std::make_shared<BuiltinModule>("aten");
} else if (name == "__torch__") {
return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
}
return nullptr;
}
TypePtr resolveType(const std::string& name, const SourceRange& loc)
override {
return nullptr;
}
};
} // namespace
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, BuiltinClass) {
script::Module m("m");
auto cls = getCustomClass(
"__torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest");
TORCH_INTERNAL_ASSERT(cls);
c10::intrusive_ptr<torch::CustomClassHolder> obj_holder;
m.register_attribute("my_obj", cls, IValue::make_capsule(obj_holder));
m.register_parameter("foo", torch::ones({}), false);
m.define(
R"(
def __getstate__(self):
return 1
def __setstate__(self, a):
self.my_obj = __torch__.torch.classes._TorchScriptTesting._LiteInterpreterTest()
def forward(self, x) -> str:
return self.my_obj.get(x)
)",
std::make_shared<TestModuleResolver>());
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
auto res =
bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
const auto& str = res.toStringRef();
std::string expected = "Hello! Your tensor has 12 elements!";
AT_ASSERT(str == expected);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, BuiltinFunction) {
script::Module m("m");
auto custom_class_obj =
make_custom_class<TorchBindLiteInterpreterTestStruct>();
m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
m.define(R"(
def forward(self, x) -> str:
return self.my_obj.get(x)
)");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
auto res =
bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
// NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
auto str = res.toStringRef();
std::string expected = "Hello! Your tensor has 12 elements!";
AT_ASSERT(str == expected);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, ModuleInfoBasic) {
Module m("M");
m.define(R"JIT(
def forward(self, x):
return 2 * x
)JIT");
std::stringstream ss;
m._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
std::unordered_set<std::string> expected_result({"top(M).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, NotSaveModuleInfo) {
Module m("M");
m.define(R"JIT(
def forward(self, x):
return x + 5
)JIT");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
AT_ASSERT(module_info.empty() || module_info == "<no module info>");
++pc;
} catch (const std::exception& e) {
break;
}
}
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, OneSubmoduleModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return 2 * x + 5
)JIT");
Module b("B");
b.register_module("A0", a);
b.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + 1
)JIT");
std::stringstream ss;
b._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
std::unordered_set<std::string> expected_result(
{"top(B).forward", "top(B).A0(A).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, TwoSubmodulesModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return x + 1
)JIT");
Module b("B");
b.define(R"JIT(
def forward(self, x):
return x + 2
)JIT");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + self.B0.forward(x)
)JIT");
std::stringstream ss;
c._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
std::unordered_set<std::string> expected_result(
{"top(C).forward", "top(C).A0(A).forward", "top(C).B0(B).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, SequentialModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return x + 1
)JIT");
Module b("B");
b.define(R"JIT(
def forward(self, x):
return x + 2
)JIT");
Module c("C");
c.register_module("A0", a);
c.register_module("B0", b);
c.define(R"JIT(
def forward(self, x):
return self.A0.forward(self.B0.forward(x))
)JIT");
std::stringstream ss;
c._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
// class A(nn.Module):
// def __init__(self):
// super(A, self).__init__()
// def forward(self, x):
// return x + 1
// class B(nn.Module):
// def __init__(self):
// super(B, self).__init__()
// def forward(self, x):
// return x + 2
// class C(nn.Module):
// def __init__(self):
// super(C, self).__init__()
// self.A0 = A()
// self.B0 = B()
// def forward(self, x):
// return self.A0.forward(self.B0.forward(x))
std::unordered_set<std::string> expected_result(
{"top(C).A0(A).forward", "top(C).B0(B).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, HierarchyModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return x + 1
)JIT");
Module b("B");
b.register_module("A0", a);
b.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + 1
)JIT");
Module c("C");
c.register_module("B0", b);
c.define(R"JIT(
def forward(self, x):
return self.B0.forward(x) + 1
)JIT");
std::stringstream ss;
c._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
// There are 3 module information strings here.
// "top(C).forward": for the add operator in top.
// "top(C).B0(B).forward": for the add operator in B0.
// "top(C).B0(B).forward.A0(A).forward": for the add operator in A0.
std::unordered_set<std::string> expected_result(
{"top(C).forward",
"top(C).B0(B).forward",
"top(C).B0(B).forward.A0(A).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, DuplicatedClassTypeModuleInfo) {
Module a("A");
a.define(R"JIT(
def forward(self, x):
return x + 5
)JIT");
Module b("B");
b.register_module("A0", a);
b.register_module("A1", a);
b.define(R"JIT(
def forward(self, x):
return self.A0.forward(x) + self.A1.forward(x)
)JIT");
std::stringstream ss;
b._save_for_mobile(ss, {}, true);
mobile::Module bc = _load_for_mobile(ss);
std::unordered_set<std::string> module_debug_info_set;
size_t pc = 0;
while (true) {
try {
std::string module_info = bc.get_forward_method_debug_info(pc);
if (!module_info.empty() && module_info != "<no module info>") {
module_debug_info_set.insert(module_info);
}
++pc;
} catch (const std::exception& e) {
break;
}
}
// class A(nn.Module):
// def __init__(self):
// super(A, self).__init__()
// def forward(self, x):
// return x + 5
// class B(nn.Module):
// def __init__(self):
// super(B, self).__init__()
// self.A0 = A()
// self.A1 = A()
// def forward(self, x):
// return self.A0.forward(x) + self.A1.forward(x)
// There are 3 module information strings here.
// "top(B).forward": for the add operator in top.
// "top(B).A0(A).forward": for the add operator in A0.
// "top(B).A1(A).forward": for the add operator in A1.
std::unordered_set<std::string> expected_result(
{"top(B).forward", "top(B).A0(A).forward", "top(B).A1(A).forward"});
AT_ASSERT(module_debug_info_set == expected_result);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, Eval) {
std::vector<torch::jit::IValue> inputs;
Module m("m");
m.define(R"(
def __init__(self, x):
self.training = True
def forward(self, input):
return torch.dropout(input, 1.0, self.training)
)");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
inputs.push_back(torch::ones({1, 1, 28, 28}));
m.eval();
auto outputref = m.forward(inputs).toTensor();
// save m in training mode to make sure that mobile eval() will correctly
// change back to eval mode
m.train();
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
bc.eval();
IValue res;
for (int i = 0; i < 3; ++i) {
res = bc.get_method("forward")(inputs);
}
auto output = res.toTensor();
AT_ASSERT(outputref.dim() == output.dim());
AT_ASSERT(
outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, FindWrongMethodName) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def add(self, x):
b = 4
return self.foo + x + b
)");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, FindAndRunMethod) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def add_it(self, x):
b = 4
return self.foo + x + b
)");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto minput = 5 * torch::ones({});
inputs.emplace_back(minput);
auto ref = m.get_method("add_it")(inputs);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res;
for (int i = 0; i < 3; ++i) {
auto bcinputs = inputs;
auto method = bc.find_method("add_it");
AT_ASSERT(method != c10::nullopt);
res = (*method)(std::move(bcinputs));
}
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, RunMethodVariadic) {
Module m("m");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def add_three(self, x, y):
return self.foo + x + y
)");
std::vector<IValue> inputs;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto inputx = 5 * torch::ones({});
auto inputy = 4 * torch::ones({});
auto ref = m.run_method("add_three", inputx, inputy);
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
IValue res = bc.run_method("add_three", inputx, inputy);
auto resd = res.toTensor().item<float>();
auto refd = ref.toTensor().item<float>();
AT_ASSERT(resd == refd);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, DuplicateSetState) {
Module m("M");
m.register_parameter("foo", torch::ones({}), false);
m.define(R"(
def __getstate__(self):
return self.foo + self.foo
def __setstate__(self, a):
self.foo = a
def forward(self, x):
b = 4
return self.foo + x + b
)");
Module b("B");
b.register_module("M0", m);
b.register_module("M1", m);
b.define(R"(
def forward(self, x):
return self.M0.forward(x) + self.M1.forward(x)
)");
std::stringstream ss;
m._save_for_mobile(ss);
mobile::Module bc = _load_for_mobile(ss);
const auto methods = bc.get_methods();
const size_t expected_n = 3;
ASSERT_EQ(methods.size(), expected_n);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, ExtraFiles) {
const auto script = R"JIT(
def forward(self):
x = torch.rand(5, 5)
x = x.mm(x)
return x
)JIT";
auto module =
std::make_shared<Module>("Module", std::make_shared<CompilationUnit>());
module->define(script);
std::ostringstream oss;
std::unordered_map<std::string, std::string> extra_files;
extra_files["metadata.json"] = "abc";
extra_files["mobile_info.json"] = "{\"key\": 23}";
module->_save_for_mobile(oss, extra_files);
std::istringstream iss(oss.str());
caffe2::serialize::IStreamAdapter adapter{&iss};
std::unordered_map<std::string, std::string> loaded_extra_files;
loaded_extra_files["metadata.json"] = "";
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
loaded_extra_files.clear();
std::vector<std::string> all_files =
caffe2::serialize::PyTorchStreamReader(&iss).getAllRecords();
for (auto& file_name : all_files) {
if (file_name.find("extra/") == 0) {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
loaded_extra_files[file_name.substr(6)] = "";
}
}
torch::jit::_load_for_mobile(iss, torch::kCPU, loaded_extra_files);
ASSERT_EQ(loaded_extra_files["metadata.json"], "abc");
ASSERT_EQ(loaded_extra_files["mobile_info.json"], "{\"key\": 23}");
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
TEST(LiteInterpreterTest, OpNameExportFetchRootOperators) {
torch::jit::Module m("m");
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
m.register_parameter("bias", torch::ones({20}), false);
m.define(R"(
def forward(self, input):
x1 = torch.zeros(2, 2)
x2 = torch.empty_like(torch.empty(2, 2))
x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
return (x1, x2, x3)
)");
m.eval();
std::stringstream ss;
m._save_for_mobile(ss);
torch::jit::mobile::Module ptl_model = torch::jit::_load_for_mobile(ss);
std::set<std::string> operator_names =
torch::jit::mobile::_export_operator_list(ptl_model);
std::set<std::string> expected_operator_names = {
"aten::_convolution",
"aten::empty.memory_format",
"aten::empty_like",
"aten::zeros",
};
EXPECT_EQ(operator_names, expected_operator_names)
<< "Expected the root operator lists to be the same";
}
namespace {
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
static auto reg =
torch::class_<TorchBindLiteInterpreterTestStruct>(
"_TorchScriptTesting",
"_LiteInterpreterTest")
.def(torch::init<>())
.def("get", &TorchBindLiteInterpreterTestStruct::get)
.def_pickle(
// __getattr__
[](const c10::intrusive_ptr<TorchBindLiteInterpreterTestStruct>&
self) -> int64_t { return 0; },
// __setattr__
[](int64_t state) {
return c10::make_intrusive<TorchBindLiteInterpreterTestStruct>();
});
} // namespace
} // namespace jit
} // namespace torch