Refactor saving jit::Module to mobile .pt in 2 steps: (#66494)
Summary:
1. is to convert Function -> mobile::Function
2. is to serialize mobile::Function
This also opens opportunity to create mobile::Module without saving/reloading
Fixes #{issue number}
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494
Reviewed By: zhxchen17
Differential Revision: D32293022
Pulled By: qihqi
fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt
index b3def38..5c40664 100644
--- a/test/cpp/jit/CMakeLists.txt
+++ b/test/cpp/jit/CMakeLists.txt
@@ -67,6 +67,7 @@
${JIT_TEST_ROOT}/test_irparser.cpp
${JIT_TEST_ROOT}/test_jit_type.cpp
${JIT_TEST_ROOT}/test_lite_interpreter.cpp
+ ${JIT_TEST_ROOT}/test_lite_interpreter_direct.cpp
${JIT_TEST_ROOT}/test_lite_trainer.cpp
${JIT_TEST_ROOT}/test_memory_dag.cpp
${JIT_TEST_ROOT}/test_misc.cpp
diff --git a/test/cpp/jit/test_lite_interpreter_direct.cpp b/test/cpp/jit/test_lite_interpreter_direct.cpp
new file mode 100644
index 0000000..31055e5
--- /dev/null
+++ b/test/cpp/jit/test_lite_interpreter_direct.cpp
@@ -0,0 +1,921 @@
+#include <test/cpp/jit/test_utils.h>
+
+#include <gtest/gtest.h>
+
+#include <c10/core/TensorOptions.h>
+#include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/frontend/resolver.h>
+#include <torch/csrc/jit/mobile/backport.h>
+#include <torch/csrc/jit/mobile/backport_manager.h>
+#include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/model_compatibility.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
+#include <torch/csrc/jit/mobile/parse_operators.h>
+#include <torch/csrc/jit/mobile/runtime_compatibility.h>
+#include <torch/csrc/jit/serialization/export.h>
+#include <torch/csrc/jit/serialization/export_bytecode.h>
+#include <torch/csrc/jit/serialization/import.h>
+#include <torch/custom_class.h>
+#include <torch/torch.h>
+
+#include <unordered_set>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
+ Module m("m");
+ m.define(R"(
+ def forward(self, input: Tensor, scale:float):
+ return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
+ )");
+
+ std::vector<IValue> inputs;
+ inputs.emplace_back(torch::rand({1, 3, 128, 128}));
+ inputs.emplace_back(at::Scalar(2.0));
+ auto ref = m.forward(inputs);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ res = bc.forward(inputs);
+
+ auto resd = res.toTensor();
+ auto refd = ref.toTensor();
+ ASSERT_TRUE(resd.equal(refd));
+}
+
+TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
+ Module m("m");
+ m.register_attribute("mobile_optimized", BoolType::get(), true);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
+
+ AT_ASSERT(mobile_optimized);
+ m.setattr("mobile_optimized", false);
+ bc = jitModuleToMobile(m, options);
+ mobile_optimized = bc.attr("mobile_optimized", false).toBool();
+ AT_ASSERT(!mobile_optimized);
+}
+
+TEST(
+ LiteInterpreterDirectTest,
+ MethodInvocation) { // NOLINT (use =delete in gtest)
+ const std::vector<std::string> test_programs{
+ // test invoking a method with default parameter
+ R"(
+ def test_func(self, x, b : int = 4):
+ return self.foo + x + b
+ )",
+ // inner method call with default parameter (gets inlined)
+ R"(
+ def add_with_default_arg(self, x, b : int = 4):
+ return self.foo + x + b
+ def test_func(self, x):
+ return self.add_with_default_arg(x) # invoke method w/ default arg
+ )",
+ // simple method call
+ R"(
+ def test_func(self, x):
+ b = 4
+ return self.foo + x + b
+ )",
+ };
+ for (const auto& test_program : test_programs) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(test_program);
+
+ const int fortyTwo = 42; // (keep linter happy)
+ auto minput = fortyTwo * torch::ones({});
+ auto ref = m.run_method("test_func", minput);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ const auto& test_func = bc.get_method("test_func");
+ std::cerr << "hello " << std::endl;
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ res = test_func({minput});
+ }
+ std::cerr << "hello 3" << std::endl;
+
+ auto resd = res.toTensor().item<float>();
+ auto refd = ref.toTensor().item<float>();
+ AT_ASSERT(resd == refd);
+ }
+}
+
+TEST(LiteInterpreterDirectTest, Conv) {
+ auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
+ if (s && strcmp(s, "1") == 0)
+ return;
+
+ std::vector<torch::jit::IValue> inputs;
+
+ Module m("m");
+ m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+ m.register_parameter("bias", torch::ones({20}), false);
+ m.define(R"(
+ def forward(self, input):
+ return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
+ )");
+
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
+ inputs.push_back(torch::ones({1, 1, 28, 28}));
+
+ auto outputref = m.forward(inputs).toTensor();
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ res = bc.get_method("forward")(inputs);
+ }
+ auto output = res.toTensor();
+ AT_ASSERT(outputref.dim() == output.dim());
+ AT_ASSERT(
+ outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+}
+
+TEST(LiteInterpreterDirectTest, Inline) {
+ Module m("m");
+ m.define(R"JIT(
+ def foo1(self, x):
+ return x + 1
+
+ def foo2(self, x):
+ return self.foo1(x) + 2
+
+ def foo3(self, x):
+ return self.foo2(x) + 3
+ )JIT");
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ std::vector<torch::jit::IValue> inputs({torch::ones({})});
+ auto output = bc.get_method("foo3")(inputs);
+ AT_ASSERT(output.toTensor().item<float>() == 7.0);
+}
+
+TEST(LiteInterpreterDirectTest, Tuple) {
+ Module m("m");
+ m.define(R"JIT(
+ def foo(self, x):
+ return (1, 2, x + 3)
+
+ def forward(self, x):
+ tuple = self.foo(x)
+ return tuple
+ )JIT");
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ std::vector<torch::jit::IValue> inputs({torch::ones({})});
+ auto output = bc.get_method("forward")(inputs);
+ AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
+}
+
+TEST(LiteInterpreterDirectTest, Dict) {
+ Module m("m");
+ m.define(R"JIT(
+ def foo(self, x):
+ return {"result": x + 1}
+
+ def forward(self, x):
+ d = self.foo(x)
+ return d
+ )JIT");
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ std::vector<torch::jit::IValue> inputs({torch::ones({})});
+ auto output = bc.get_method("forward")(inputs);
+ AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
+}
+
+TEST(LiteInterpreterDirectTest, Prim) {
+ Module m("m");
+ m.define(R"JIT(
+ def forward(self, x):
+ return int(x)
+ )JIT");
+
+ std::vector<IValue> inputs;
+ auto minput = 3.5 * torch::ones({});
+ inputs.emplace_back(minput);
+ auto ref = m.run_method("forward", minput);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+ auto bcinputs = inputs;
+ res = bc.get_method("forward")(bcinputs);
+ }
+
+ auto resi = res.toInt();
+ auto refi = ref.toInt();
+ AT_ASSERT(resi == refi);
+}
+
+TEST(LiteInterpreterDirectTest, PrimScalar) {
+ Module m("m");
+ m.define(R"JIT(
+ def forward(self, x):
+ return int(x.item())
+ )JIT");
+
+ std::vector<IValue> inputs;
+ auto minput = 3.5 * torch::ones({});
+ inputs.emplace_back(minput);
+ auto ref = m.run_method("forward", minput);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+ auto bcinputs = inputs;
+ res = bc.get_method("forward")(bcinputs);
+ }
+
+ auto resi = res.toInt();
+ auto refi = ref.toInt();
+ AT_ASSERT(resi == refi);
+}
+
+TEST(LiteInterpreterDirectTest, WrongMethodName) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def add(self, x):
+ b = 4
+ return self.foo + x + b
+ )");
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ std::vector<IValue> inputs;
+ auto minput = 5 * torch::ones({});
+ inputs.emplace_back(minput);
+ ASSERT_THROWS_WITH_MESSAGE(
+ bc.get_method("forward")(inputs), "is not defined");
+}
+
+TEST(LiteInterpreterDirectTest, SetState) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def __getstate__(self):
+ return self.foo
+ def __setstate__(self, a):
+ self.foo = a
+ def forward(self, x):
+ b = 4
+ return self.foo + x + b
+ )");
+
+ std::vector<IValue> inputs;
+ auto minput = 5 * torch::ones({});
+ inputs.emplace_back(minput);
+
+ std::stringstream ms;
+ m.save(ms);
+ auto loaded_m = load(ms);
+ auto ref = loaded_m.run_method("forward", minput);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+ auto bcinputs = inputs;
+ res = bc.get_method("forward")(bcinputs);
+ }
+
+ auto resd = res.toTensor().item<float>();
+ auto refd = ref.toTensor().item<float>();
+ AT_ASSERT(resd == refd);
+}
+
+class TorchBindLiteInterpreterDirectTestStruct
+ : public torch::jit::CustomClassHolder {
+ public:
+ std::string get(at::Tensor t) {
+ std::stringstream ss;
+ ss << "Hello! Your tensor has ";
+ ss << t.numel();
+ ss << " elements!";
+ return ss.str();
+ }
+};
+
+namespace {
+struct ClassNamespaceValue : public SugaredValue {
+ explicit ClassNamespaceValue(c10::QualifiedName name)
+ : basename_(std::move(name)) {}
+
+ std::shared_ptr<SugaredValue> attr(
+ const SourceRange&,
+ GraphFunction&,
+ const std::string& name) override {
+ const auto fullName = c10::QualifiedName(basename_, name);
+
+ // Check to see if it is a custom class.
+ if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
+ return std::make_shared<ClassValue>(custom_class);
+ }
+
+ // If it's not a custom class, assume it's another namespace
+ // NOLINTNEXTLINE(performance-move-const-arg)
+ return std::make_shared<ClassNamespaceValue>(fullName);
+ }
+
+ std::string kind() const override {
+ return "Class Namespace";
+ }
+
+ private:
+ c10::QualifiedName basename_;
+};
+
+struct TestModuleResolver : public Resolver {
+ std::shared_ptr<SugaredValue> resolveValue(
+ const std::string& name,
+ GraphFunction&,
+ const SourceRange&) override {
+ if (name == "torch") {
+ return std::make_shared<BuiltinModule>("aten");
+ } else if (name == "__torch__") {
+ return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
+ }
+
+ return nullptr;
+ }
+
+ TypePtr resolveType(const std::string&, const SourceRange&) override {
+ return nullptr;
+ }
+};
+} // namespace
+
+TEST(LiteInterpreterDirectTest, BuiltinFunction) {
+ script::Module m("m");
+ auto custom_class_obj =
+ make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
+ m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
+ m.define(R"(
+ def forward(self, x) -> str:
+ return self.my_obj.get(x)
+ )");
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ auto res =
+ bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
+ // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+ auto str = res.toStringRef();
+ std::string expected = "Hello! Your tensor has 12 elements!";
+ AT_ASSERT(str == expected);
+}
+
+#if !defined FB_XPLAT_BUILD
+TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
+ auto runtime_bytecode_version = _get_runtime_bytecode_version();
+ AT_ASSERT(
+ runtime_bytecode_version ==
+ caffe2::serialize::kMaxSupportedBytecodeVersion);
+}
+
+TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
+ auto runtime_operators_version = _get_runtime_operators_min_max_versions();
+ AT_ASSERT(
+ runtime_operators_version.first ==
+ caffe2::serialize::kMinSupportedFileFormatVersion &&
+ runtime_operators_version.second ==
+ caffe2::serialize::kMaxSupportedFileFormatVersion);
+}
+
+/**
+ * The test below is disarmed for FB internal xplat builds since
+ * BUCK requires us to pass in the script_module_v4.ptl file in
+ * as a resource dependency of the build rule for this file, and
+ * we would need to access it via the C++ Resources API instead
+ * of directly reading from disk (which is what the open source
+ * build/run does).
+ */
+TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
+ std::string filePath(__FILE__);
+ auto test_model_file_v4 =
+ filePath.substr(0, filePath.find_last_of("/\\") + 1);
+ test_model_file_v4.append("script_module_v4.ptl");
+
+ auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
+ AT_ASSERT(version_v4 == 4);
+}
+
+#endif // !defined(FB_XPLAT_BUILD)
+
+TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
+ auto runtime_ops = _get_runtime_ops_and_info();
+ // Ballpark estimate of the minimal number of ops; just used to
+ // verify API returns a reasonably large number.
+ AT_ASSERT(runtime_ops.size() > 2900);
+}
+
+TEST(LiteInterpreterDirectTest, Eval) {
+ std::vector<torch::jit::IValue> inputs;
+
+ Module m("m");
+ m.define(R"(
+ def __init__(self, x):
+ self.training = True
+
+ def forward(self, input):
+ return torch.dropout(input, 1.0, self.training)
+ )");
+
+ // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
+ inputs.push_back(torch::ones({1, 1, 28, 28}));
+ m.eval();
+ auto outputref = m.forward(inputs).toTensor();
+
+ // save m in training mode to make sure that mobile eval() will correctly
+ // change back to eval mode
+ m.train();
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ bc.eval();
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ res = bc.get_method("forward")(inputs);
+ }
+ auto output = res.toTensor();
+ AT_ASSERT(outputref.dim() == output.dim());
+ AT_ASSERT(
+ outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+}
+
+TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def add(self, x):
+ b = 4
+ return self.foo + x + b
+ )");
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
+}
+
+TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def add_it(self, x):
+ b = 4
+ return self.foo + x + b
+ )");
+
+ std::vector<IValue> inputs;
+ auto minput = 5 * torch::ones({});
+ inputs.emplace_back(minput);
+ auto ref = m.get_method("add_it")(inputs);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ auto bcinputs = inputs;
+ auto method = bc.find_method("add_it");
+ AT_ASSERT(method != c10::nullopt);
+ res = (*method)(std::move(bcinputs));
+ }
+
+ auto resd = res.toTensor().item<float>();
+ auto refd = ref.toTensor().item<float>();
+ AT_ASSERT(resd == refd);
+}
+
+TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
+ Module m("m");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def add_three(self, x, y):
+ return self.foo + x + y
+ )");
+
+ std::vector<IValue> inputs;
+ auto inputx = 5 * torch::ones({});
+ auto inputy = 4 * torch::ones({});
+ auto ref = m.run_method("add_three", inputx, inputy);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res = bc.run_method("add_three", inputx, inputy);
+
+ auto resd = res.toTensor().item<float>();
+ auto refd = ref.toTensor().item<float>();
+ AT_ASSERT(resd == refd);
+}
+
+TEST(LiteInterpreterDirectTest, DuplicateSetState) {
+ Module m("M");
+ m.register_parameter("foo", torch::ones({}), false);
+ m.define(R"(
+ def __getstate__(self):
+ return self.foo + self.foo
+ def __setstate__(self, a):
+ self.foo = a
+ def forward(self, x):
+ b = 4
+ return self.foo + x + b
+ )");
+
+ Module b("B");
+ b.register_module("M0", m);
+ b.register_module("M1", m);
+ b.define(R"(
+ def forward(self, x):
+ return self.M0.forward(x) + self.M1.forward(x)
+ )");
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ const auto methods = bc.get_methods();
+ const size_t expected_n = 3;
+ ASSERT_EQ(methods.size(), expected_n);
+}
+
+TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
+ torch::jit::Module m("m");
+ m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+ m.register_parameter("bias", torch::ones({20}), false);
+ m.define(R"(
+ def forward(self, input):
+ x1 = torch.zeros(2, 2)
+ x2 = torch.empty_like(torch.empty(2, 2))
+ x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
+ return (x1, x2, x3)
+ )");
+ m.eval();
+
+ CompilationOptions options;
+ mobile::Module ptl_model = jitModuleToMobile(m, options);
+ std::set<std::string> operator_names =
+ torch::jit::mobile::_export_operator_list(ptl_model);
+ std::set<std::string> expected_operator_names = {
+ "aten::_convolution",
+ "aten::empty.memory_format",
+ "aten::empty_like",
+ "aten::zeros",
+ };
+ EXPECT_EQ(operator_names, expected_operator_names)
+ << "Expected the root operator lists to be the same";
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
+ auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
+ if (s && strcmp(s, "1") == 0)
+ return;
+
+ std::vector<torch::jit::IValue> inputs;
+
+ Module m("m");
+ m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+ m.register_parameter("bias", torch::ones({20}), false);
+ m.define(R"(
+ def forward(self, input):
+ return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
+ )");
+
+ inputs.emplace_back(torch::ones({1, 1, 28, 28}));
+
+ auto outputref = m.forward(inputs).toTensor();
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 1; ++i) {
+ res = bc.get_method("forward")(inputs);
+ }
+ auto output = res.toTensor();
+ AT_ASSERT(outputref.dim() == output.dim());
+ AT_ASSERT(output.equal(outputref));
+}
+
+namespace {
+void testLiteModuleCompareResultTensors(
+ Module& m,
+ const std::vector<torch::jit::IValue>& inputs,
+ const std::string& method_name = "forward") {
+ auto outputref = m.get_method(method_name)(inputs).toTensor();
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ IValue res;
+ for (int i = 0; i < 3; ++i) {
+ res = bc.get_method(method_name)(inputs);
+ }
+ auto output = res.toTensor();
+ AT_ASSERT(outputref.dim() == output.dim());
+ AT_ASSERT(output.equal(outputref));
+}
+
+void testDefaultArgsPinv2(int num_args) {
+ Module m("m");
+ if (num_args == 1) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input)
+ )");
+ } else if (num_args == 2) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input, 1e-5)
+ )");
+ } else if (num_args == 3) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input, 1e-5, True)
+ )");
+ }
+
+ std::vector<torch::jit::IValue> inputs;
+ const int N = 28;
+ auto input = torch::range(1, N * N, 1);
+ input[0] = 1; // a more stable matrix
+ input = input.view({N, N});
+ inputs.emplace_back(input);
+ testLiteModuleCompareResultTensors(m, inputs);
+}
+} // namespace
+
+#if !defined FB_XPLAT_BUILD
+TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
+ // Test with different number of specified arguments.
+ // Arguments not specified take default value.
+ for (int num_args = 1; num_args <= 3; ++num_args) {
+ testDefaultArgsPinv2(num_args);
+ }
+
+ // bytecode with one specified argument:
+ // (6,
+ // ('__torch__.m.forward',
+ // (('instructions',
+ // (('STOREN', 1, 2),
+ // ('DROPR', 1, 0),
+ // ('MOVE', 2, 0),
+ // ('OP', 0, 0),
+ // ('RET', 0, 0))),
+ // ('operators', (('aten::linalg_pinv', '', 1),)),
+ // ('constants', (False, 1e-15)), # default constants are not
+ // used
+ // ('types', ()),
+ // ('register_size', 2)),
+ // (('arguments',
+ // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+ // None)),
+ // (('name', 'input'), ('type', 'Tensor'), ('default_value',
+ // None)))),
+ // ('returns',
+ // ((('name', ''), ('type', 'Tensor'), ('default_value',
+ // None)),)))))
+
+ // bytecode with 2 specified argument:
+ // (6,
+ // ('__torch__.m.forward',
+ // (('instructions',
+ // (('STOREN', 1, 2),
+ // ('DROPR', 1, 0),
+ // ('MOVE', 2, 0),
+ // ('LOADC', 1, 0), # added LOADC for specified argument
+ // ('OP', 0, 0),
+ // ('RET', 0, 0))),
+ // ('operators', (('aten::linalg_pinv', '', 2),)),
+ // ('constants', (False, 1e-05)), # updated constant table
+ // ('types', ()),
+ // ('register_size', 2)),
+ // (('arguments',
+ // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+ // None)),
+ // (('name', 'input'), ('type', 'Tensor'), ('default_value',
+ // None)))),
+ // ('returns',
+ // ((('name', ''), ('type', 'Tensor'), ('default_value',
+ // None)),)))))
+
+ // bytecode with 3 specified arguments:
+ // (6,
+ // ('__torch__.m.forward',
+ // (('instructions',
+ // (('STOREN', 1, 2),
+ // ('DROPR', 1, 0),
+ // ('MOVE', 2, 0),
+ // ('LOADC', 1, 0),
+ // ('LOADC', 0, 0),
+ // ('OP', 0, 0),
+ // ('RET', 0, 0))),
+ // ('operators', (('aten::linalg_pinv', '', 3),)),
+ // ('constants', (True, 1e-05)),
+ // ('types', ()),
+ // ('register_size', 2)),
+ // (('arguments',
+ // ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+ // None)),
+ // (('name', 'input'), ('type', 'Tensor'), ('default_value',
+ // None)))),
+ // ('returns',
+ // ((('name', ''), ('type', 'Tensor'), ('default_value',
+ // None)),)))))
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
+ // The second argument is specified, but the value is the same as the default
+ // value. It's treated as "not specified" since the value can be fetched from
+ // schema.
+ Module m("m");
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_tensorinv(input, 2)
+ )");
+ torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
+ auto arg_nums = code.op_to_num_specified_args();
+ ASSERT_EQ(arg_nums.size(), 1);
+ ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
+ std::vector<torch::jit::IValue> inputs;
+ const int N = 4;
+ auto input = torch::rand({N, N, N, N});
+ inputs.emplace_back(input);
+ testLiteModuleCompareResultTensors(m, inputs);
+}
+
+void testDefaultArgsPinvWithOutArg2(int num_args) {
+ Module m("m");
+ if (num_args == 1) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input, out=input)
+ )");
+ } else if (num_args == 2) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input, 1e-5, out=input)
+ )");
+ } else if (num_args == 3) {
+ m.define(R"(
+ def forward(self, input):
+ return torch.linalg_pinv(input, 1e-5, True, out=input)
+ )");
+ }
+
+ const int N = 28;
+ auto input = torch::range(1, N * N, 1);
+ input[0] = 10000; // a more stable matrix
+ input = input.view({N, N});
+ auto ref = m.run_method("forward", input);
+ TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
+ TORCH_CHECK(input.equal(ref.toTensor()));
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
+ // Test with different number of specified arguments + out arg.
+ // Arguments not specified take default value.
+ for (int num_args = 1; num_args <= 3; ++num_args) {
+ testDefaultArgsPinvWithOutArg2(num_args);
+ }
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
+ Module m("m");
+ m.define(R"(
+ def forward(self, x, h):
+ torch.add(x, h, out=x)
+ )");
+
+ std::vector<IValue> inputs;
+ auto input_x = 2 * torch::ones({});
+ auto input_h = torch::ones({});
+ auto ref = m.run_method("forward", input_x, input_h);
+
+ CompilationOptions options;
+ mobile::Module bc = jitModuleToMobile(m, options);
+ bc.run_method("forward", input_x, input_h);
+ AT_ASSERT(input_x.equal(4 * torch::ones({})));
+}
+
+TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
+ Module a("A");
+ a.define(R"(
+ def bar(self, x, y):
+ return x + y
+ )");
+ Module b("B");
+ b.register_module("A0", a);
+ b.define(R"(
+ def foo(self, x, y):
+ return self.A0.bar(x, y) + 2
+ )");
+ Module c("C");
+ c.register_module("B0", b);
+ c.define(R"(
+ def forward(self, x, y):
+ return self.B0.foo(x, y) + 3
+ )");
+
+ std::vector<IValue> inputs;
+ inputs.emplace_back(torch::rand({2, 4}));
+ inputs.emplace_back(torch::rand({13, 9}));
+
+ CompilationOptions options;
+ auto lite_m = jitModuleToMobile(c, options);
+ std::string error_pattern = R"(
+ Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
+Traceback of TorchScript (most recent call last):
+ File "<string>", line 3, in <unknown>
+
+ def forward(self, x, y):
+ return self.B0.foo(x, y) + 3
+ ~~~~~~~~~~~ <--- HERE
+
+ File "<string>", line 3, in foo
+
+ def foo(self, x, y):
+ return self.A0.bar(x, y) + 2
+ ~~~~~~~~~~~ <--- HERE
+
+ File "<string>", line 3, in bar
+
+ def bar(self, x, y):
+ return x + y
+ ~~~~~ <--- HERE
+ )";
+ ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
+}
+#endif // !defined(FB_XPLAT_BUILD)
+
+namespace {
+static auto reg =
+ torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
+ "_TorchScriptTesting",
+ "_LiteInterpreterDirectTest")
+ .def(torch::init<>())
+ .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
+ .def_pickle(
+ // __getattr__
+ [](const c10::intrusive_ptr<
+ TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
+ return 0;
+ },
+ // __setattr__
+ [](int64_t) {
+ return c10::make_intrusive<
+ TorchBindLiteInterpreterDirectTestStruct>();
+ });
+
+} // namespace
+
+TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
+ // Create 3 methods:
+ //
+ // 1. forward() returns a tensor with dtype=torch.int64 (4)
+ // 2. forward2() returns a tensor with dtype=torch.float32 (6)
+ // 3. forward3() returns a tensor with dtype=torch.float32 but
+ // the dtype is inferred by the input tensor's dtype
+ //
+ // If caching works correctly, then the result from the full-jit
+ // module and the lite module will be the same. Otherwise, it
+ // will be different if we don't correctly ignore the cache
+ // entry for an operator that has a different number of
+ // arguments.
+ Module m("m");
+ m.define(R"(
+ def forward(self):
+ ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
+ return ret1.fill_(25)
+ )");
+ m.define(R"(
+ def forward2(self):
+ ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
+ return ret1.fill_(32.0)
+ )");
+ m.define(R"(
+ def forward3(self):
+ ret1 = torch.new_empty(torch.zeros(10), [10])
+ return ret1.fill_(12.0)
+ )");
+
+ std::vector<torch::jit::IValue> inputs;
+ testLiteModuleCompareResultTensors(m, inputs, "forward");
+ testLiteModuleCompareResultTensors(m, inputs, "forward2");
+ testLiteModuleCompareResultTensors(m, inputs, "forward3");
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h
index 5fd20a9..27e61b8 100644
--- a/torch/csrc/jit/api/function_impl.h
+++ b/torch/csrc/jit/api/function_impl.h
@@ -24,6 +24,10 @@
void run(Stack& stack) override;
+ std::function<void(GraphFunction&)> function_creator() const {
+ return function_creator_;
+ }
+
c10::intrusive_ptr<c10::ivalue::Future> runAsync(
Stack& stack,
TaskLauncher taskLauncher = at::launch) override;
diff --git a/torch/csrc/jit/mobile/code.h b/torch/csrc/jit/mobile/code.h
index 5ea6617..c91a76b 100644
--- a/torch/csrc/jit/mobile/code.h
+++ b/torch/csrc/jit/mobile/code.h
@@ -20,6 +20,7 @@
std::vector<Instruction> instructions_;
std::vector<DebugHandle> debug_handles_;
std::vector<c10::OperatorName> op_names_;
+ std::vector<int> operator_input_sizes_;
std::vector<std::function<void(Stack&)>> operators_;
std::vector<c10::IValue> constants_;
std::vector<c10::TypePtr> types_;
diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h
index 444573c..bbeb4ed 100644
--- a/torch/csrc/jit/mobile/debug_info.h
+++ b/torch/csrc/jit/mobile/debug_info.h
@@ -23,6 +23,10 @@
MobileDebugTable(
std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
const std::shared_ptr<CompilationUnit>& cu);
+
+ template <typename It>
+ MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {}
+
std::string getSourceDebugString(
const int64_t debug_handle,
const std::string& top_module_type_name = "ModuleTypeUnknown") const;
@@ -36,6 +40,11 @@
const std::vector<int64_t>& debug_handles,
const std::string& top_module_type_name = "ModuleTypeUnknown") const;
+ const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap()
+ const {
+ return callstack_ptr_map_;
+ }
+
private:
std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo(
const std::vector<int64_t>& debug_handles,
diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp
index bc1bd53..8af058e 100644
--- a/torch/csrc/jit/mobile/function.cpp
+++ b/torch/csrc/jit/mobile/function.cpp
@@ -13,6 +13,14 @@
Function::Function(c10::QualifiedName name)
: name_(std::move(name)), code_(std::make_shared<Code>()) {}
+Function::Function(
+ c10::QualifiedName name,
+ std::shared_ptr<Code> code,
+ at::optional<c10::FunctionSchema> schema)
+ : name_(std::move(name)),
+ code_(std::move(code)),
+ schema_(std::move(schema)) {}
+
const c10::QualifiedName& Function::qualname() const {
return name_;
}
@@ -43,89 +51,11 @@
// Keep the original opname in code_
code_->op_names_.emplace_back(name, overload_name);
const auto& opname = code_->op_names_.back();
- const auto full_name = c10::toString(opname);
-
- std::function<void(Stack&)> fn;
-
- const std::vector<c10::Argument>* pArgs = nullptr;
- bool promoted_op = mobile::hasPrimOpsFn(full_name);
- if (promoted_op) {
- fn = mobile::getPrimOpsFn(full_name);
- } else {
- std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
- if (jit_op) {
- fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
- pArgs = &jit_op->schema().arguments();
- } else {
- auto op = c10::Dispatcher::singleton().findSchema(opname);
- if (op.has_value()) {
- fn = [op](Stack& stack) { op->callBoxed(&stack); };
- if (op->hasSchema()) {
- pArgs = &op->schema().arguments();
- } else {
- TORCH_CHECK(false, "arguments are missing for operator ", opname);
- }
- } else {
- return false;
- }
- }
+ auto func = makeOperatorFunction(opname, num_specified_args, model_version);
+ if (!func.has_value()) {
+ return false;
}
-
- if (!promoted_op) {
- TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
- const auto& args = *pArgs;
- if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
- opname.overload_name.empty()) {
- // Since byte-code versions 0x4L, convolution has an additional
- // default-value argument (allow_tf32=True, see
- // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
- // backward compatibility with models of byte-code version <= 0x3L, where
- // this bool argument does not yet exist.
- fn = [fn](Stack& stack) {
- stack.push_back(true);
- fn(stack);
- };
- } else {
- // num_specified_args >= 0 indicates number of arguments are available
- // from model. We can use it to handle backward compatibility.
- if (num_specified_args &&
- num_specified_args.value() < static_cast<int64_t>(args.size())) {
- fn = [fn, num_specified_args, &args](Stack& stack) {
- std::vector<IValue> out_args;
- // The following logic pops and temporarily stores all out arguments
- // from the stack (which can be 0 or more, and always appended to the
- // schema), in order to push the necessary default values. Finally,
- // the out arguments are pushed back into the stack.
- for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
- out_args.push_back(stack.back());
- stack.pop_back();
- }
- size_t start_index = num_specified_args.value() - out_args.size();
- TORCH_CHECK(
- start_index >= 0,
- "The number of output arguments is: ",
- out_args.size(),
- ", which is more then the number of specified arguments: ",
- num_specified_args.value());
- for (size_t i = start_index; i < (args.size() - out_args.size());
- ++i) {
- TORCH_CHECK(
- args[i].default_value().has_value(),
- "Error happened at preparing for default values for the argument. The ",
- i,
- "th argument ",
- args[i].name(),
- " does not have a specified value or default value. ");
-
- stack.push_back(args[i].default_value());
- }
- stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
- fn(stack);
- };
- }
- }
- }
- code_->operators_.emplace_back(fn);
+ code_->operators_.emplace_back(*func);
return true;
}
@@ -197,6 +127,93 @@
return getInterpretersExceptionDebugHandles();
}
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+ c10::OperatorName opname,
+ c10::optional<int> num_specified_args,
+ int64_t model_version) {
+ std::function<void(Stack&)> fn;
+ const auto full_name = c10::toString(opname);
+ const std::vector<c10::Argument>* pArgs = nullptr;
+ bool promoted_op = mobile::hasPrimOpsFn(full_name);
+ if (promoted_op) {
+ fn = mobile::getPrimOpsFn(full_name);
+ } else {
+ std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
+ if (jit_op) {
+ fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
+ pArgs = &jit_op->schema().arguments();
+ } else {
+ auto op = c10::Dispatcher::singleton().findSchema(opname);
+ if (op.has_value()) {
+ fn = [op](Stack& stack) { op->callBoxed(&stack); };
+ if (op->hasSchema()) {
+ pArgs = &op->schema().arguments();
+ } else {
+ TORCH_CHECK(false, "arguments are missing for operator ", opname);
+ }
+ } else {
+ return c10::nullopt;
+ }
+ }
+ }
+
+ if (!promoted_op) {
+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
+ const auto& args = *pArgs;
+ if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
+ opname.overload_name.empty()) {
+ // Since byte-code versions 0x4L, convolution has an additional
+ // default-value argument (allow_tf32=True, see
+ // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
+ // backward compatibility with models of byte-code version <= 0x3L, where
+ // this bool argument does not yet exist.
+ fn = [fn](Stack& stack) {
+ stack.push_back(true);
+ fn(stack);
+ };
+ } else {
+ // num_specified_args >= 0 indicates number of arguments are available
+ // from model. We can use it to handle backward compatibility.
+ if (num_specified_args &&
+ num_specified_args.value() < static_cast<int64_t>(args.size())) {
+ fn = [fn, num_specified_args, &args](Stack& stack) {
+ std::vector<IValue> out_args;
+ // The following logic pops and temporarily stores all out arguments
+ // from the stack (which can be 0 or more, and always appended to the
+ // schema), in order to push the necessary default values. Finally,
+ // the out arguments are pushed back into the stack.
+ for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
+ out_args.push_back(stack.back());
+ stack.pop_back();
+ }
+ size_t start_index = num_specified_args.value() - out_args.size();
+ TORCH_CHECK(
+ start_index >= 0,
+ "The number of output arguments is: ",
+ out_args.size(),
+ ", which is more then the number of specified arguments: ",
+ num_specified_args.value());
+ for (size_t i = start_index; i < (args.size() - out_args.size());
+ ++i) {
+ TORCH_CHECK(
+ args[i].default_value().has_value(),
+ "Error happened at preparing for default values for the argument. The ",
+ i,
+ "th argument ",
+ args[i].name(),
+ " does not have a specified value or default value. ");
+
+ stack.push_back(args[i].default_value());
+ }
+ stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
+ fn(stack);
+ };
+ }
+ }
+ }
+ return fn;
+}
+
} // namespace mobile
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h
index 5e20515..c29638a 100644
--- a/torch/csrc/jit/mobile/function.h
+++ b/torch/csrc/jit/mobile/function.h
@@ -17,6 +17,10 @@
class TORCH_API Function : public torch::jit::Function {
public:
explicit Function(c10::QualifiedName name);
+ Function(
+ c10::QualifiedName name,
+ std::shared_ptr<Code> code,
+ at::optional<c10::FunctionSchema> schema);
void run(Stack& stack) override;
at::IValue operator()(Stack& stack);
void ensure_defined() override {}
@@ -24,6 +28,9 @@
const c10::QualifiedName& qualname() const override;
bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
+ // NOTE: the APIs below is dangerous: if you call append_instruction with
+ // dbg_handle and then call it without; then the dbg_handle will become
+ // misaligned. Therefore only use ONE variant at time.
void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
void append_instruction(OpCode op, int X, int N);
bool append_operator(
@@ -56,6 +63,11 @@
at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
};
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+ c10::OperatorName opname,
+ c10::optional<int> num_specified_args,
+ int64_t model_version);
+
} // namespace mobile
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp
index 99e60c4..fddd3aa 100644
--- a/torch/csrc/jit/mobile/interpreter.cpp
+++ b/torch/csrc/jit/mobile/interpreter.cpp
@@ -94,15 +94,15 @@
debug_handle = *handle;
}
- // std::cout << "RUNNING " << pc << " "
- // << code_->instructions_with_handles_[pc].instruction;
+ // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
// if (inst.op == OP) {
- // std::cout << ", " << code_->op_names_[inst.X].name;
- // if (!code_->op_names_[inst.X].overload_name.empty()) {
- // std::cout << "." << code_->op_names_[inst.X].overload_name;
+ // std::cout << ", " << code.op_names_[inst.X].name;
+ // if (!code.op_names_[inst.X].overload_name.empty()) {
+ // std::cout << "." << code.op_names_[inst.X].overload_name;
// }
// }
// std::cout << std::endl;
+ // std::cout << "top " << stack.back().tagKind() << std::endl;
// TODO(iliacher): remove the workaround after RecordFunction is in
// Dispatcher
diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h
index 4d00087..31d3f58 100644
--- a/torch/csrc/jit/mobile/module.h
+++ b/torch/csrc/jit/mobile/module.h
@@ -135,7 +135,7 @@
std::unordered_map<std::string, std::string> metadata_;
std::shared_ptr<CompilationUnit> cu_;
MobileDebugTable debug_table_;
- bool has_debug_handles_;
+ bool has_debug_handles_ = false;
};
} // namespace mobile
} // namespace jit
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index d102a50..32e567e 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -33,7 +33,6 @@
#endif
#include <exception>
-#include <iostream>
#include <memory>
#include <mutex>
#include <ostream>
diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp
index 2e37fa7..a1348dd 100644
--- a/torch/csrc/jit/serialization/export_bytecode.cpp
+++ b/torch/csrc/jit/serialization/export_bytecode.cpp
@@ -1,23 +1,333 @@
#include <torch/csrc/jit/serialization/export_bytecode.h>
+#include <utility>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/serialization/export.h>
+#include <c10/util/Exception.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/api/method.h>
+#include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/backends/backend_debug_info.h>
+#include <torch/csrc/jit/frontend/source_range.h>
+#include <torch/csrc/jit/ir/attributes.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/type_hashing.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/method.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
+#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
+#include <torch/csrc/jit/serialization/import_export_helpers.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+#include <torch/csrc/jit/serialization/python_print.h>
+#include <torch/csrc/jit/serialization/source_range_serialization.h>
+#include <torch/csrc/jit/serialization/type_name_uniquer.h>
+
+#include <caffe2/serialize/inline_container.h>
+
namespace torch {
namespace jit {
-void BytecodeExportSet::add(
- const c10::QualifiedName& qn,
- ExportedFunction exported) {
- items_.emplace(qn, std::move(exported));
+std::vector<Method> gatherGetSetStates(ObjectPtr obj) {
+ std::vector<Method> methods;
+ // Use DFS on IValue's to traverse dependencies of module._ivalue and
+ // add all setstate/getstates to initial stack.
+ std::vector<ObjectPtr> ivalue_stack;
+ ivalue_stack.emplace_back(obj);
+ while (!ivalue_stack.empty()) {
+ ObjectPtr cur = ivalue_stack.back();
+ ivalue_stack.pop_back();
+ auto type = cur->type();
+ Function* setstate = type->findMethod("__setstate__");
+ Function* getstate = type->findMethod("__getstate__");
+ if (getstate && setstate) {
+ if (setstate->isGraphFunction()) {
+ methods.emplace_back(cur, setstate);
+ }
+ if (getstate->isGraphFunction()) {
+ methods.emplace_back(cur, getstate);
+ }
+ } else {
+ for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
+ IValue field = cur->getSlot(i);
+ if (field.isObject()) {
+ ivalue_stack.emplace_back(field.toObject());
+ }
+ }
+ }
+ }
+ return methods;
}
-void BytecodeExportSet::update(const c10::QualifiedName& qn, bool toplevel) {
- items_.at(qn).toplevel = toplevel;
+std::vector<Method> findAllDependentFunctions(
+ const Module& module,
+ Graph& graph) {
+ std::vector<Method> methods;
+ std::unordered_set<c10::string_view> called_method_names;
+ auto nodes = findAllNodes(graph, c10::prim::CallMethod, true);
+ for (Node* node : nodes) {
+ if (auto iface = node->input(0)->type()->castRaw<InterfaceType>()) {
+ const FunctionSchema* schema = iface->getMethod(node->s(attr::name));
+ called_method_names.insert(schema->name());
+ }
+ }
+
+ for (const auto& submodule : module.modules()) {
+ for (const auto& m : submodule.get_methods()) {
+ if (called_method_names.find(m.function().qualname().name()) !=
+ called_method_names.end()) {
+ methods.emplace_back(m);
+ }
+ }
+ }
+ return methods;
}
-bool BytecodeExportSet::contains(const c10::QualifiedName& qn) const {
- return items_.find(qn) != items_.end();
+// NOTE: order of functions returned will be:
+// 1. functions originated from the methods passed in will be first
+// 2. All the dependent functions will come afterwards.
+// This order is meaningful because currently mobile Module looks up
+// methods with linear search.
+std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
+ const std::vector<Method>& initial_methods,
+ bool incl_dependent_functions) {
+ std::set<std::pair<std::string, Function*>> visited;
+ std::deque<Method> stack;
+ std::copy(
+ initial_methods.begin(),
+ initial_methods.end(),
+ std::back_inserter(stack));
+ std::vector<std::unique_ptr<GraphFunction>> inlined_functions;
+ while (!stack.empty()) {
+ Method cur = stack.front();
+ stack.pop_front();
+ auto tup = std::make_pair(
+ cur.owner()._ivalue()->type()->name()->qualifiedName(),
+ &cur.function());
+ if (visited.find(tup) != visited.end()) {
+ continue;
+ }
+ visited.insert(tup);
+ const auto& f = toGraphFunction(cur.function());
+ auto graph = f.graph()->copyUnique();
+ Inline(*graph);
+ c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name());
+
+ if (incl_dependent_functions) {
+ std::vector<Method> dependent_methods =
+ findAllDependentFunctions(cur.owner(), *graph);
+ std::copy(
+ dependent_methods.begin(),
+ dependent_methods.end(),
+ std::back_inserter(stack));
+ }
+ auto inlined_func = std::make_unique<GraphFunction>(
+ qn, std::move(graph), f.function_creator());
+ inlined_func->setSchema(f.getSchema());
+ inlined_functions.emplace_back(std::move(inlined_func));
+ }
+ return inlined_functions;
+}
+
+std::unique_ptr<mobile::Code> compileGraphToMobileCode(
+ const std::string& name,
+ const std::shared_ptr<Graph>& graph,
+ const CompilationOptions& compilation_options,
+ BackendDebugInfoRecorder& debug_info_recorder) {
+ MobileCode code(
+ graph,
+ name,
+ compilation_options.enable_default_value_for_unspecified_arg,
+ compilation_options.enable_default_args_before_out_args);
+
+ std::unique_ptr<mobile::Code> mobile_code_ptr =
+ std::make_unique<mobile::Code>();
+ mobile::Code& mobile_code = *mobile_code_ptr;
+
+ // operator names
+ std::vector<std::string> method_names;
+ std::vector<int64_t> op_debug_handles;
+ int next_new_op_index = 0;
+
+ auto op_to_specified_args = code.op_to_num_specified_args();
+
+ for (size_t i = 0; i < code.instructions().size(); ++i) {
+ Instruction ins = code.instructions()[i];
+
+ if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
+ // Found a new op (assumes new operators ordered by ascending ins.X)
+ auto node = code.instructions_source()[i];
+ const c10::OperatorName& opname = node->schema().operator_name();
+ auto unique_name = c10::toString(opname);
+ // For operator with vararg, adding default arguments would be confusing
+ // and is not allowed. For an operator with num_args = -1, it means the
+ // number of arguments is not available for this operator, we don't do any
+ // backward compatibility adaptation at runtime.
+ c10::optional<int> num_args = c10::nullopt;
+ auto it = op_to_specified_args.find(unique_name);
+ if (it != op_to_specified_args.end()) {
+ num_args = it->second;
+ }
+ mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1));
+ mobile_code.op_names_.emplace_back(opname);
+ auto func = mobile::makeOperatorFunction(
+ opname, num_args, compilation_options.model_version);
+ TORCH_INTERNAL_ASSERT(
+ func.has_value(),
+ "Operator with name: ",
+ toString(opname),
+ " not found");
+ mobile_code.operators_.emplace_back(*func);
+ next_new_op_index++;
+ }
+ // CALL nodes at this point represent built-in (i.e. non-Graph)
+ // functions that were not inlined. Here we convert the CALL
+ // instructions for these functions into INTERFACE_CALL instructions
+ // s.t. at runtime, we will look up the Function* on the Type of the
+ // 0th argument in the stack and call that directly.
+ if (ins.op == CALL) {
+ auto node = code.instructions_source()[i];
+ if (node->kind() == prim::CallMethod) {
+ // NB: replacing instruction
+ auto method_name_idx =
+ code.constant_table().size() + method_names.size();
+ method_names.emplace_back(node->s(attr::name));
+ ins = Instruction{
+ INTERFACE_CALL,
+ static_cast<int32_t>(method_name_idx),
+ static_cast<uint16_t>(node->inputs().size())};
+ } else {
+ TORCH_INTERNAL_ASSERT(
+ false, "Unsupported node kind on CALL opcode for mobile");
+ }
+ } else if (ins.op == RET) {
+ auto node = code.instructions_source()[i];
+ for (const auto& input : node->inputs()) {
+ const auto& input_type = input->type();
+ if (input_type->kind() == TypeKind::ListType ||
+ input_type->kind() == TypeKind::DictType) {
+ for (const TypePtr& element_type : input_type->containedTypes()) {
+ TORCH_CHECK(
+ element_type->kind() != TypeKind::ClassType,
+ "Returining a list or dictionary with pytorch class type ",
+ "is not supported in mobile module "
+ "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
+ "Workaround: instead of using pytorch class as their element type, ",
+ "use a combination of list, dictionary, and single types.");
+ }
+ }
+ }
+ } else {
+ TORCH_CHECK(
+ isOpSupportedInMobile(ins.op),
+ toString(ins.op),
+ " is not supported in mobile module.");
+ }
+ auto node = code.instructions_source()[i];
+ int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
+ // Note 1-to-1 correspondence between instructions and debug handles
+ mobile_code.instructions_.emplace_back(ins);
+ mobile_code.debug_handles_.emplace_back(debug_handle);
+ }
+
+ // copy constants
+ mobile_code.constants_ = code.constant_table();
+
+ // Make a copy of the constants and append the method names
+ // that we emitted for the converted INTERFACE_CALL nodes above.
+ for (auto& method_name : method_names) {
+ mobile_code.constants_.emplace_back(method_name);
+ }
+
+ mobile_code.types_ = code.type_table();
+ mobile_code.register_size_ = code.register_size();
+ return mobile_code_ptr;
+}
+
+void checkSchema(const FunctionSchema& schema) {
+ TORCH_CHECK(
+ schema.overload_name().empty(), // @TODO: is this check correct?
+ "Overloads are not supported in mobile modules.");
+ TORCH_CHECK(
+ !schema.is_vararg(), "Python *args are not supported in mobile modules.");
+ TORCH_CHECK(
+ !schema.is_varret(),
+ "A variable number of return values is not supported in mobile modules.");
+}
+
+bool isLoweredModule(const Module& m) {
+ c10::QualifiedName type_name;
+ if (m.type()->name()) {
+ type_name = m.type()->name().value();
+ }
+ bool isLoweredModule = false;
+ for (const auto& atom : type_name.atoms()) {
+ if (atom == "LoweredModule") {
+ isLoweredModule = true;
+ break;
+ }
+ }
+ return isLoweredModule;
+}
+
+// Check if the global static map of backend debug info
+// contains debug info for this module and any of its children.
+// If so combine all the maps together and return one.
+void getBackendDebugInfoMap(
+ const Module& m,
+ BackendDebugInfoMapType& debug_map) {
+ if (isLoweredModule(m)) {
+ auto backend_debug_info =
+ m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
+ const auto& map = backend_debug_info->getDebugInfoMap();
+ if (map) {
+ debug_map.insert(map.value().begin(), map.value().end());
+ }
+ }
+ for (const auto& c : m.children()) {
+ getBackendDebugInfoMap(c, debug_map);
+ }
+}
+
+mobile::Module jitModuleToMobile(
+ const Module& module,
+ const CompilationOptions& options) {
+ std::shared_ptr<mobile::CompilationUnit> mcu =
+ std::make_shared<mobile::CompilationUnit>();
+ BackendDebugInfoRecorder debug_info_recorder;
+
+ std::vector<Method> methods_to_export = module.get_methods();
+ std::vector<Method> getsetstates = gatherGetSetStates(module._ivalue());
+ std::copy(
+ getsetstates.begin(),
+ getsetstates.end(),
+ std::back_inserter(methods_to_export));
+
+ for (const auto& func :
+ inlineFunctions(methods_to_export, options.incl_interface_call)) {
+ std::shared_ptr<mobile::Code> mobile_code_ptr = compileGraphToMobileCode(
+ func->name(), func->graph(), options, debug_info_recorder);
+ const auto& schema = func->getSchema();
+ checkSchema(schema);
+ auto mobile_func = std::make_unique<mobile::Function>(
+ func->qualname(), mobile_code_ptr, schema);
+ mcu->register_function(std::move(mobile_func));
+ }
+
+ mobile::Module m(module._ivalue(), mcu);
+ m.setHasDebugHandles(true);
+ BackendDebugInfoMapType backend_debug_info_map;
+ getBackendDebugInfoMap(module, backend_debug_info_map);
+ auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
+ debug_handle_cs_ptr_map.insert(
+ backend_debug_info_map.begin(), backend_debug_info_map.end());
+ m.setDebugTable(MobileDebugTable(
+ debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
+ return m;
}
} // namespace jit
diff --git a/torch/csrc/jit/serialization/export_bytecode.h b/torch/csrc/jit/serialization/export_bytecode.h
index f97434b..8782b5d 100644
--- a/torch/csrc/jit/serialization/export_bytecode.h
+++ b/torch/csrc/jit/serialization/export_bytecode.h
@@ -1,59 +1,31 @@
#pragma once
+#include <tuple>
#include <unordered_map>
#include <ATen/core/function_schema.h>
#include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
#include <ATen/core/qualified_name.h>
#include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/runtime/interpreter.h>
#include <torch/csrc/jit/serialization/type_name_uniquer.h>
namespace torch {
namespace jit {
-struct ExportedFunction {
- ExportedFunction(
- const Module& m,
- const Function& f,
- std::unique_ptr<Graph> g,
- bool t)
- : mod(m), function(f), optimizedGraph(std::move(g)), toplevel(t) {}
- Module mod;
- const Function& function;
- std::unique_ptr<Graph> optimizedGraph;
- bool toplevel;
+struct TORCH_API CompilationOptions {
+ bool incl_interface_call = false;
+ bool enable_default_value_for_unspecified_arg = false;
+ bool enable_default_args_before_out_args = true;
+ int model_version = caffe2::serialize::kProducedBytecodeVersion;
};
-class TORCH_API BytecodeExportSet {
- public:
- BytecodeExportSet() = default;
- BytecodeExportSet(const BytecodeExportSet&) = delete;
- BytecodeExportSet& operator=(const BytecodeExportSet&) = delete;
- BytecodeExportSet(BytecodeExportSet&&) = default;
- BytecodeExportSet& operator=(BytecodeExportSet&&) = default;
-
- void add(const c10::QualifiedName& qn, ExportedFunction);
- void update(const c10::QualifiedName& qn, bool toplevel);
- bool contains(const c10::QualifiedName& qn) const;
-
- template <typename F>
- void visit(F&& f) {
- for (auto& item : items_) {
- if (item.second.toplevel) {
- f(item.first, item.second);
- }
- }
- for (auto& item : items_) {
- if (!item.second.toplevel) {
- f(item.first, item.second);
- }
- }
- }
-
- private:
- std::unordered_map<c10::QualifiedName, ExportedFunction> items_;
-};
+TORCH_API mobile::Module jitModuleToMobile(
+ const Module& module,
+ const CompilationOptions& options);
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index 61d39fa..a13193b 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -38,6 +38,18 @@
namespace torch {
namespace jit {
+CompilationOptions getOptionsFromGlobal() {
+ CompilationOptions compilation_options;
+ compilation_options.enable_default_args_before_out_args =
+ BytecodeEmitMode::is_default_args_before_out_args_enabled();
+ compilation_options.enable_default_value_for_unspecified_arg =
+ BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+ compilation_options.incl_interface_call = getMobileInterfaceCallExport();
+ compilation_options.model_version =
+ caffe2::serialize::kProducedBytecodeVersion;
+ return compilation_options;
+}
+
IValue to_tuple(std::initializer_list<IValue> ivalues) {
return c10::ivalue::Tuple::create(ivalues);
}
@@ -63,138 +75,49 @@
}
std::pair<IValue, IValue> getFunctionTuple(
- const Module& module,
- const Function& func,
- std::unique_ptr<Graph> optimizedGraph,
+ const CompilationUnit& compilation_unit,
+ const mobile::Function& func,
BackendDebugInfoRecorder& debug_info_recorder,
- const std::string& qn,
TypeNameUniquer& type_name_uniquer_) {
- TORCH_INTERNAL_ASSERT(optimizedGraph);
- std::shared_ptr<MobileCode> code;
- code = std::make_shared<MobileCode>(
- std::move(optimizedGraph), func.name(), BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */, BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
- auto instructions_copy = code->instructions();
-
- // operator names
- std::vector<c10::OperatorName> opnames;
- std::vector<std::string> method_names;
- std::vector<int64_t> op_debug_handles;
- int next_new_op_index = 0;
- for (size_t i = 0; i < instructions_copy.size(); ++i) {
- Instruction ins = instructions_copy[i];
- if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
- // Found a new op (assumes new operators ordered by ascending ins.X)
- auto node = code->instructions_source()[i];
- opnames.emplace_back(node->schema().operator_name());
- next_new_op_index++;
- }
- // CALL nodes at this point represent built-in (i.e. non-Graph)
- // functions that were not inlined. Here we convert the CALL
- // instructions for these functions into INTERFACE_CALL instructions
- // s.t. at runtime, we will look up the Function* on the Type of the
- // 0th argument in the stack and call that directly.
- if (ins.op == CALL) {
- auto node = code->instructions_source()[i];
- if (node->kind() == prim::CallMethod) {
- // NB: replacing instruction
- auto method_name_idx =
- code->constant_table().size() + method_names.size();
- method_names.emplace_back(node->s(attr::name));
- Instruction new_instr{
- INTERFACE_CALL,
- static_cast<int32_t>(method_name_idx),
- static_cast<uint16_t>(node->inputs().size())};
- instructions_copy[i] = new_instr;
- } else {
- TORCH_INTERNAL_ASSERT(
- false, "Unsupported node kind on CALL opcode for mobile");
- }
- } else if (ins.op == RET) {
- auto node = code->instructions_source()[i];
- for (const auto& input : node->inputs()) {
- const auto& input_type = input->type();
- if (input_type->kind() == TypeKind::ListType ||
- input_type->kind() == TypeKind::DictType) {
- for (const TypePtr& element_type : input_type->containedTypes()) {
- TORCH_CHECK(
- element_type->kind() != TypeKind::ClassType,
- "Returining a list or dictionary with pytorch class type ",
- "is not supported in mobile module "
- "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
- "Workaround: instead of using pytorch class as their element type, ",
- "use a combination of list, dictionary, and single types.");
- }
- }
- }
- } else {
- TORCH_CHECK(
- isOpSupportedInMobile(ins.op),
- toString(ins.op),
- " is not supported in mobile module.");
- }
- auto node = code->instructions_source()[i];
- int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
- // Note 1-to-1 correspondence between instructions and debug handles
- op_debug_handles.emplace_back(debug_handle);
- }
+ const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
// instructions
std::vector<IValue> instructions;
- instructions.reserve(instructions_copy.size());
- for (Instruction ins : instructions_copy) {
+ instructions.reserve(mobile_code_ptr->instructions_.size());
+ for (Instruction ins : mobile_code_ptr->instructions_) {
instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
}
// operators
std::vector<IValue> operators;
- auto op_to_specified_args = code->op_to_num_specified_args();
- operators.reserve(opnames.size());
- for (const auto& opname : opnames) {
- auto unique_name = c10::toString(opname);
- // For operator with vararg, adding default arguments would be confusing and
- // is not allowed. For an operator with num_args = -1, it means the number
- // of arguments is not available for this operator, we don't do any backward
- // compatibility adaptation at runtime.
- int num_args = -1;
- auto it = op_to_specified_args.find(unique_name);
- if (it != op_to_specified_args.end()) {
- num_args = it->second;
- }
+ operators.reserve(mobile_code_ptr->op_names_.size());
+ for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
+ const auto& opname = mobile_code_ptr->op_names_[i];
+ const int size = mobile_code_ptr->operator_input_sizes_[i];
if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
} else {
operators.emplace_back(
- to_tuple({opname.name, opname.overload_name, num_args}));
+ to_tuple({opname.name, opname.overload_name, size}));
}
}
- // constants
- //
- // Make a copy of the constants and append the method names
- // that we emitted for the converted INTERFACE_CALL nodes above.
- auto constants = code->constant_table();
- for (auto& method_name : method_names) {
- constants.emplace_back(std::move(method_name));
- }
-
// types
std::vector<IValue> types;
- types.reserve(code->type_table().size());
+ types.reserve(mobile_code_ptr->types_.size());
static const std::string torch_prefix("__torch__");
static const std::string class_prefix("__torch__.torch.classes");
- std::shared_ptr<torch::jit::CompilationUnit> cu =
- module._ivalue()->compilation_unit();
- for (const TypePtr& t : code->type_table()) {
+ for (const TypePtr& t : mobile_code_ptr->types_) {
std::string type_str = t->annotation_str();
if (t->kind() == TypeKind::TupleType) {
TORCH_CHECK(
- cu->get_named_tuple(t->str()),
+ compilation_unit.get_named_tuple(t->str()),
"Can't find definition for the qualified name: ",
t->str(),
"(TypeKind::TupleType) in compilation unit.",
"Please report a bug to PyTorch.");
- auto named_tuple_type = cu->get_named_tuple(t->str());
+ auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
if (named_tuple_type != nullptr) {
std::string named_tuple_str = t->str();
named_tuple_str.append("[NamedTuple, [");
@@ -254,12 +177,12 @@
// since the register location is embedded into the bytecode, pass the
// register size
- auto register_size = static_cast<int>(code->register_size());
+ auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
auto codeTable = Table(
{{"instructions", to_tuple(instructions)},
{"operators", to_tuple(operators)},
- {"constants", to_tuple(constants)},
+ {"constants", to_tuple(mobile_code_ptr->constants_)},
{"types", to_tuple(types)},
{"register_size", register_size}});
@@ -273,14 +196,7 @@
}
return c10::nullopt;
};
- TORCH_CHECK(
- schema.overload_name().empty(), // @TODO: is this check correct?
- "Overloads are not supported in mobile modules.");
- TORCH_CHECK(
- !schema.is_vararg(), "Python *args are not supported in mobile modules.");
- TORCH_CHECK(
- !schema.is_varret(),
- "A variable number of return values is not supported in mobile modules.");
+
auto makeArgTuple = [&](const std::vector<Argument>& args) {
std::vector<IValue> argTables;
for (auto&& arg : args) {
@@ -315,6 +231,17 @@
});
// function tuple
+ std::string qn;
+ if (func.name() == "__setstate__" || func.name() == "__getstate__") {
+ auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
+ TORCH_INTERNAL_ASSERT(
+ classtype, "class is null ", func.qualname().qualifiedName());
+ qn = c10::QualifiedName(
+ type_name_uniquer_.getUniqueName(classtype), func.name())
+ .qualifiedName();
+ } else {
+ qn = func.qualname().qualifiedName();
+ }
auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
c10::optional<IValue> debug_info_vals;
@@ -324,41 +251,27 @@
// debug handles generated by debug_handle_manager
// will correspond to {source_range, inlinedCallStackPtr} which we will
// serialize separately.
- IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
+ IValue module_debug_tuple =
+ c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
auto function_debug_info =
Table({{"function_debug_handles", module_debug_tuple}});
debug_info_vals = to_tuple({qn, function_debug_info});
return std::make_pair(bytecode_vals, debug_info_vals);
}
-void pushFunctionToIValues(
- BytecodeExportSet exportSet,
+void pushMobileFunctionsToIValues(
+ const CompilationUnit& compilation_unit,
+ const mobile::Module& module,
std::vector<c10::IValue>& elements,
std::vector<c10::IValue>& debugInfoElements,
BackendDebugInfoRecorder& recorder,
TypeNameUniquer& uniquer) {
- exportSet.visit(
- [&](const c10::QualifiedName& qn, ExportedFunction& exported) {
- auto tuple = getFunctionTuple(
- exported.mod,
- exported.function,
- std::move(exported.optimizedGraph),
- recorder,
- qn.qualifiedName(),
- uniquer);
- elements.push_back(std::move(tuple.first));
- debugInfoElements.push_back(std::move(tuple.second));
- });
-}
-
-void pushFunctionToIValues(
- BytecodeExportSet exportSet,
- std::vector<c10::IValue>& elements,
- BackendDebugInfoRecorder& recorder,
- TypeNameUniquer& uniquer) {
- std::vector<c10::IValue> debugInfoElements;
- pushFunctionToIValues(
- std::move(exportSet), elements, debugInfoElements, recorder, uniquer);
+ for (const auto& method : module.get_methods()) {
+ auto tuple = getFunctionTuple(
+ compilation_unit, method.function(), recorder, uniquer);
+ elements.push_back(std::move(tuple.first));
+ debugInfoElements.push_back(std::move(tuple.second));
+ }
}
std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
@@ -402,61 +315,6 @@
return ret;
}
-void exportFunction(
- BytecodeExportSet& exportSet,
- const ModuleMethod& method,
- bool toplevel = false) {
- const auto& func = method.function;
- const auto& qn = method.exportName;
- if (exportSet.contains(qn)) {
- if (toplevel) {
- exportSet.update(qn, toplevel);
- }
- return;
- }
- auto graph = func.graph()->copyUnique();
- Inline(*graph);
- auto interfaceCalls = getInterfaceCalls(*graph);
- exportSet.add(
- qn, ExportedFunction{method.module, func, std::move(graph), toplevel});
-
- if (!getMobileInterfaceCallExport()) {
- return;
- }
-
- auto interfaces = getModuleInterfaceExports(method.module, interfaceCalls);
- for (const auto& interface : interfaces) {
- exportFunction(exportSet, interface);
- }
-}
-
-void setstateTuple(
- BytecodeExportSet& exportSet,
- const Module& module,
- const IValue& ivalue,
- TypeNameUniquer& type_name_uniquer_,
- bool toplevel = false) {
- if (!ivalue.isObject())
- return;
- auto obj = ivalue.toObject();
- auto type = obj->type();
- if (checkHasValidSetGetState(type)) {
- Function& setstate = type->getMethod("__setstate__");
- auto qn = type_name_uniquer_.getUniqueName(obj->type()).qualifiedName() +
- "." + setstate.name();
- if (exportSet.contains(qn)) {
- return;
- }
- if (auto f = tryToGraphFunction(setstate)) {
- exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
- }
- } else {
- for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
- setstateTuple(exportSet, module, obj->getSlot(i), type_name_uniquer_);
- }
- }
-}
-
bool isLoweredModule(const Module& m) {
c10::QualifiedName type_name;
if (m.type()->name()) {
@@ -544,24 +402,6 @@
return mobileInterfaceCallExport().load(std::memory_order_relaxed);
}
-BytecodeExportSet moduleMethodsTuple(
- const Module& module,
- TypeNameUniquer& type_name_uniquer_) {
- BytecodeExportSet exportSet;
- auto methods = module.get_methods();
- // top level methods
- for (const auto& method : methods) {
- const auto& f = toGraphFunction(method.function());
- exportFunction(
- exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
- }
-
- // __setstate__ of all components
- setstateTuple(exportSet, module, module._ivalue(), type_name_uniquer_, true);
-
- return exportSet;
-}
-
void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
GetExtraFilesHook() = std::move(hook);
}
@@ -774,9 +614,12 @@
// Always save debug handles
debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
- BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_);
- pushFunctionToIValues(
- std::move(exportSet),
+ mobile::Module mobile_module =
+ jitModuleToMobile(module, getOptionsFromGlobal());
+
+ pushMobileFunctionsToIValues(
+ *module._ivalue()->compilation_unit(),
+ mobile_module,
elements,
debug_info_elements,
debug_info_recorder,
@@ -840,9 +683,9 @@
getBackendDebugInfoMap(module, backend_debug_info_map);
// Now get the debug-handles-to-inlined-cs-ptr-map
// And serialize that in a separate archive
- auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
- debug_handle_cs_ptr_map.insert(
- backend_debug_info_map.begin(), backend_debug_info_map.end());
+ const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
+ BackendDebugInfoMapType debug_handle_cs_ptr_map(
+ debug_info.begin(), debug_info.end());
CallStackDebugInfoPickler cs_debug_info_pickler;
auto cs_data = cs_debug_info_pickler.pickle(
debug_handle_cs_ptr_map, source_range_tags_);
@@ -962,31 +805,13 @@
namespace {
void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
- std::vector<c10::IValue> elements;
- BackendDebugInfoRecorder dummy;
- TypeNameUniquer dummy_uniquer = TypeNameUniquer();
- BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
- pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
- for (const auto& element : elements) {
- auto table = element.toTupleRef().elements()[1];
- auto row =
- table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
- TORCH_INTERNAL_ASSERT(
- row->elements().at(0).toStringRef() == "operators",
- "Expected operators but found ",
- row->elements().at(0).toStringRef());
- const auto& ops_list = row->elements().at(1).toTupleRef().elements();
- for (const auto& op : ops_list) {
- const auto& op_item = op.toTupleRef().elements();
- TORCH_CHECK(
- op_item.size() >= 2,
- "There should be either two parts (name and overload name), ",
- "or three parts (name, overload name and number of specified args) ",
- "for an operator.");
- auto opname = op_item[0].toString()->string();
- auto overload = op_item[1].toString()->string();
+ mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
+ for (const auto& method : mobile_m.get_methods()) {
+ for (const auto& op : method.function().get_code()->op_names_) {
// NOLINTNEXTLINE(performance-inefficient-string-concatenation)
- opnames.emplace(overload.empty() ? opname : opname + "." + overload);
+ opnames.emplace(
+ op.overload_name.empty() ? op.name
+ : op.name + "." + op.overload_name);
}
}
}