Refactor saving jit::Module to mobile .pt in 2 steps: (#66494)

Summary:
1. Convert Function -> mobile::Function.
2. Serialize mobile::Function.

This also opens the opportunity to create a mobile::Module in memory, without saving and reloading.
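
A minimal sketch of the new in-memory path (CompilationOptions, jitModuleToMobile, and mobile::Module come from this diff; the example module body and tensor shapes are only illustrative):

    #include <torch/csrc/jit/api/module.h>
    #include <torch/csrc/jit/serialization/export_bytecode.h>
    #include <torch/torch.h>

    int main() {
      // Build a scripted module as usual.
      torch::jit::Module m("m");
      m.define(R"(
        def forward(self, x):
          return x + 1
      )");

      // Step 1: convert jit::Module -> mobile::Module directly; no .pt round trip.
      torch::jit::CompilationOptions options;
      torch::jit::mobile::Module mobile_m = torch::jit::jitModuleToMobile(m, options);

      // The mobile module is immediately runnable; step 2 (serializing the
      // mobile::Function bytecode to a .pt file) is now a separate concern.
      std::vector<c10::IValue> inputs;
      inputs.emplace_back(torch::ones({}));
      auto out = mobile_m.get_method("forward")(inputs);
      return 0;
    }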

Fixes #{issue number}

Pull Request resolved: https://github.com/pytorch/pytorch/pull/66494

Reviewed By: zhxchen17

Differential Revision: D32293022

Pulled By: qihqi

fbshipit-source-id: 29b43d47ff86071d5e2f9d6ca4dba4445711ce3d
diff --git a/test/cpp/jit/CMakeLists.txt b/test/cpp/jit/CMakeLists.txt
index b3def38..5c40664 100644
--- a/test/cpp/jit/CMakeLists.txt
+++ b/test/cpp/jit/CMakeLists.txt
@@ -67,6 +67,7 @@
   ${JIT_TEST_ROOT}/test_irparser.cpp
   ${JIT_TEST_ROOT}/test_jit_type.cpp
   ${JIT_TEST_ROOT}/test_lite_interpreter.cpp
+  ${JIT_TEST_ROOT}/test_lite_interpreter_direct.cpp
   ${JIT_TEST_ROOT}/test_lite_trainer.cpp
   ${JIT_TEST_ROOT}/test_memory_dag.cpp
   ${JIT_TEST_ROOT}/test_misc.cpp
diff --git a/test/cpp/jit/test_lite_interpreter_direct.cpp b/test/cpp/jit/test_lite_interpreter_direct.cpp
new file mode 100644
index 0000000..31055e5
--- /dev/null
+++ b/test/cpp/jit/test_lite_interpreter_direct.cpp
@@ -0,0 +1,921 @@
+#include <test/cpp/jit/test_utils.h>
+
+#include <gtest/gtest.h>
+
+#include <c10/core/TensorOptions.h>
+#include <torch/csrc/autograd/generated/variable_factories.h>
+#include <torch/csrc/jit/api/module.h>
+#include <torch/csrc/jit/frontend/resolver.h>
+#include <torch/csrc/jit/mobile/backport.h>
+#include <torch/csrc/jit/mobile/backport_manager.h>
+#include <torch/csrc/jit/mobile/import.h>
+#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/model_compatibility.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/parse_bytecode.h>
+#include <torch/csrc/jit/mobile/parse_operators.h>
+#include <torch/csrc/jit/mobile/runtime_compatibility.h>
+#include <torch/csrc/jit/serialization/export.h>
+#include <torch/csrc/jit/serialization/export_bytecode.h>
+#include <torch/csrc/jit/serialization/import.h>
+#include <torch/custom_class.h>
+#include <torch/torch.h>
+
+#include <unordered_set>
+
+// Tests go in torch::jit
+namespace torch {
+namespace jit {
+
+TEST(LiteInterpreterDirectTest, UpsampleNearest2d) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, input: Tensor, scale:float):
+      return torch.upsample_nearest2d(input, [1, 1], float(scale), float(scale))
+  )");
+
+  std::vector<IValue> inputs;
+  inputs.emplace_back(torch::rand({1, 3, 128, 128}));
+  inputs.emplace_back(at::Scalar(2.0));
+  auto ref = m.forward(inputs);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  res = bc.forward(inputs);
+
+  auto resd = res.toTensor();
+  auto refd = ref.toTensor();
+  ASSERT_TRUE(resd.equal(refd));
+}
+
+TEST(LiteInterpreterDirectTest, CheckAttrAccess) {
+  Module m("m");
+  m.register_attribute("mobile_optimized", BoolType::get(), true);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  bool mobile_optimized = bc.attr("mobile_optimized", false).toBool();
+
+  AT_ASSERT(mobile_optimized);
+  m.setattr("mobile_optimized", false);
+  bc = jitModuleToMobile(m, options);
+  mobile_optimized = bc.attr("mobile_optimized", false).toBool();
+  AT_ASSERT(!mobile_optimized);
+}
+
+TEST(
+    LiteInterpreterDirectTest,
+    MethodInvocation) { // NOLINT (use =delete in gtest)
+  const std::vector<std::string> test_programs{
+      // test invoking a method with default parameter
+      R"(
+      def test_func(self, x, b : int = 4):
+        return self.foo + x + b
+      )",
+      // inner method call with default parameter (gets inlined)
+      R"(
+      def add_with_default_arg(self, x, b : int = 4):
+        return self.foo + x + b
+      def test_func(self, x):
+        return self.add_with_default_arg(x)  # invoke method w/ default arg
+      )",
+      // simple method call
+      R"(
+      def test_func(self, x):
+        b = 4
+        return self.foo + x + b
+      )",
+  };
+  for (const auto& test_program : test_programs) {
+    Module m("m");
+    m.register_parameter("foo", torch::ones({}), false);
+    m.define(test_program);
+
+    const int fortyTwo = 42; // (keep linter happy)
+    auto minput = fortyTwo * torch::ones({});
+    auto ref = m.run_method("test_func", minput);
+
+    CompilationOptions options;
+    mobile::Module bc = jitModuleToMobile(m, options);
+    const auto& test_func = bc.get_method("test_func");
+    std::cerr << "hello " << std::endl;
+    IValue res;
+    for (int i = 0; i < 3; ++i) {
+      res = test_func({minput});
+    }
+    std::cerr << "hello 3" << std::endl;
+
+    auto resd = res.toTensor().item<float>();
+    auto refd = ref.toTensor().item<float>();
+    AT_ASSERT(resd == refd);
+  }
+}
+
+TEST(LiteInterpreterDirectTest, Conv) {
+  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
+  if (s && strcmp(s, "1") == 0)
+    return;
+
+  std::vector<torch::jit::IValue> inputs;
+
+  Module m("m");
+  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+  m.register_parameter("bias", torch::ones({20}), false);
+  m.define(R"(
+    def forward(self, input):
+      return torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
+  )");
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
+  inputs.push_back(torch::ones({1, 1, 28, 28}));
+
+  auto outputref = m.forward(inputs).toTensor();
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    res = bc.get_method("forward")(inputs);
+  }
+  auto output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(
+      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+}
+
+TEST(LiteInterpreterDirectTest, Inline) {
+  Module m("m");
+  m.define(R"JIT(
+  def foo1(self, x):
+      return x + 1
+
+  def foo2(self, x):
+      return self.foo1(x) + 2
+
+  def foo3(self, x):
+      return self.foo2(x) + 3
+  )JIT");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  std::vector<torch::jit::IValue> inputs({torch::ones({})});
+  auto output = bc.get_method("foo3")(inputs);
+  AT_ASSERT(output.toTensor().item<float>() == 7.0);
+}
+
+TEST(LiteInterpreterDirectTest, Tuple) {
+  Module m("m");
+  m.define(R"JIT(
+  def foo(self, x):
+      return (1, 2, x + 3)
+
+  def forward(self, x):
+      tuple = self.foo(x)
+      return tuple
+  )JIT");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  std::vector<torch::jit::IValue> inputs({torch::ones({})});
+  auto output = bc.get_method("forward")(inputs);
+  AT_ASSERT(output.toTupleRef().elements()[1].toInt() == 2);
+}
+
+TEST(LiteInterpreterDirectTest, Dict) {
+  Module m("m");
+  m.define(R"JIT(
+  def foo(self, x):
+      return {"result": x + 1}
+
+  def forward(self, x):
+      d = self.foo(x)
+      return d
+  )JIT");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  std::vector<torch::jit::IValue> inputs({torch::ones({})});
+  auto output = bc.get_method("forward")(inputs);
+  AT_ASSERT(output.toGenericDict().at("result").toTensor().item().toInt() == 2);
+}
+
+TEST(LiteInterpreterDirectTest, Prim) {
+  Module m("m");
+  m.define(R"JIT(
+        def forward(self, x):
+            return int(x)
+  )JIT");
+
+  std::vector<IValue> inputs;
+  auto minput = 3.5 * torch::ones({});
+  inputs.emplace_back(minput);
+  auto ref = m.run_method("forward", minput);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+    auto bcinputs = inputs;
+    res = bc.get_method("forward")(bcinputs);
+  }
+
+  auto resi = res.toInt();
+  auto refi = ref.toInt();
+  AT_ASSERT(resi == refi);
+}
+
+TEST(LiteInterpreterDirectTest, PrimScalar) {
+  Module m("m");
+  m.define(R"JIT(
+        def forward(self, x):
+            return int(x.item())
+  )JIT");
+
+  std::vector<IValue> inputs;
+  auto minput = 3.5 * torch::ones({});
+  inputs.emplace_back(minput);
+  auto ref = m.run_method("forward", minput);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+    auto bcinputs = inputs;
+    res = bc.get_method("forward")(bcinputs);
+  }
+
+  auto resi = res.toInt();
+  auto refi = ref.toInt();
+  AT_ASSERT(resi == refi);
+}
+
+TEST(LiteInterpreterDirectTest, WrongMethodName) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def add(self, x):
+      b = 4
+      return self.foo + x + b
+  )");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  std::vector<IValue> inputs;
+  auto minput = 5 * torch::ones({});
+  inputs.emplace_back(minput);
+  ASSERT_THROWS_WITH_MESSAGE(
+      bc.get_method("forward")(inputs), "is not defined");
+}
+
+TEST(LiteInterpreterDirectTest, SetState) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def __getstate__(self):
+      return self.foo
+    def __setstate__(self, a):
+      self.foo = a
+    def forward(self, x):
+      b = 4
+      return self.foo + x + b
+  )");
+
+  std::vector<IValue> inputs;
+  auto minput = 5 * torch::ones({});
+  inputs.emplace_back(minput);
+
+  std::stringstream ms;
+  m.save(ms);
+  auto loaded_m = load(ms);
+  auto ref = loaded_m.run_method("forward", minput);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+    auto bcinputs = inputs;
+    res = bc.get_method("forward")(bcinputs);
+  }
+
+  auto resd = res.toTensor().item<float>();
+  auto refd = ref.toTensor().item<float>();
+  AT_ASSERT(resd == refd);
+}
+
+class TorchBindLiteInterpreterDirectTestStruct
+    : public torch::jit::CustomClassHolder {
+ public:
+  std::string get(at::Tensor t) {
+    std::stringstream ss;
+    ss << "Hello! Your tensor has ";
+    ss << t.numel();
+    ss << " elements!";
+    return ss.str();
+  }
+};
+
+namespace {
+struct ClassNamespaceValue : public SugaredValue {
+  explicit ClassNamespaceValue(c10::QualifiedName name)
+      : basename_(std::move(name)) {}
+
+  std::shared_ptr<SugaredValue> attr(
+      const SourceRange&,
+      GraphFunction&,
+      const std::string& name) override {
+    const auto fullName = c10::QualifiedName(basename_, name);
+
+    // Check to see if it is a custom class.
+    if (auto custom_class = getCustomClass(fullName.qualifiedName())) {
+      return std::make_shared<ClassValue>(custom_class);
+    }
+
+    // If it's not a custom class, assume it's another namespace
+    // NOLINTNEXTLINE(performance-move-const-arg)
+    return std::make_shared<ClassNamespaceValue>(fullName);
+  }
+
+  std::string kind() const override {
+    return "Class Namespace";
+  }
+
+ private:
+  c10::QualifiedName basename_;
+};
+
+struct TestModuleResolver : public Resolver {
+  std::shared_ptr<SugaredValue> resolveValue(
+      const std::string& name,
+      GraphFunction&,
+      const SourceRange&) override {
+    if (name == "torch") {
+      return std::make_shared<BuiltinModule>("aten");
+    } else if (name == "__torch__") {
+      return std::make_shared<ClassNamespaceValue>(c10::QualifiedName(name));
+    }
+
+    return nullptr;
+  }
+
+  TypePtr resolveType(const std::string&, const SourceRange&) override {
+    return nullptr;
+  }
+};
+} // namespace
+
+TEST(LiteInterpreterDirectTest, BuiltinFunction) {
+  script::Module m("m");
+  auto custom_class_obj =
+      make_custom_class<TorchBindLiteInterpreterDirectTestStruct>();
+  m.register_attribute("my_obj", custom_class_obj.type(), custom_class_obj);
+  m.define(R"(
+    def forward(self, x) -> str:
+      return self.my_obj.get(x)
+  )");
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  auto res =
+      bc.get_method("forward")(std::vector<IValue>{torch::zeros({3, 4})});
+  // NOLINTNEXTLINE(performance-unnecessary-copy-initialization)
+  auto str = res.toStringRef();
+  std::string expected = "Hello! Your tensor has 12 elements!";
+  AT_ASSERT(str == expected);
+}
+
+#if !defined FB_XPLAT_BUILD
+TEST(LiteInterpreterDirectTest, GetRuntimeByteCodeVersion) {
+  auto runtime_bytecode_version = _get_runtime_bytecode_version();
+  AT_ASSERT(
+      runtime_bytecode_version ==
+      caffe2::serialize::kMaxSupportedBytecodeVersion);
+}
+
+TEST(LiteInterpreterDirectTest, GetRuntimeOperatorsVersion) {
+  auto runtime_operators_version = _get_runtime_operators_min_max_versions();
+  AT_ASSERT(
+      runtime_operators_version.first ==
+          caffe2::serialize::kMinSupportedFileFormatVersion &&
+      runtime_operators_version.second ==
+          caffe2::serialize::kMaxSupportedFileFormatVersion);
+}
+
+/**
+ * The test below is disarmed for FB internal xplat builds since
+ * BUCK requires us to pass in the script_module_v4.ptl file
+ * as a resource dependency of the build rule for this file, and
+ * we would need to access it via the C++ Resources API instead
+ * of directly reading from disk (which is what the open source
+ * build/run does).
+ */
+TEST(LiteInterpreterDirectTest, GetByteCodeVersion) {
+  std::string filePath(__FILE__);
+  auto test_model_file_v4 =
+      filePath.substr(0, filePath.find_last_of("/\\") + 1);
+  test_model_file_v4.append("script_module_v4.ptl");
+
+  auto version_v4 = _get_model_bytecode_version(test_model_file_v4);
+  AT_ASSERT(version_v4 == 4);
+}
+
+#endif // !defined(FB_XPLAT_BUILD)
+
+TEST(LiteInterpreterDirectTest, GetRuntimeOpsAndInfo) {
+  auto runtime_ops = _get_runtime_ops_and_info();
+  // Ballpark estimate of the minimal number of ops; just used to
+  // verify API returns a reasonably large number.
+  AT_ASSERT(runtime_ops.size() > 2900);
+}
+
+TEST(LiteInterpreterDirectTest, Eval) {
+  std::vector<torch::jit::IValue> inputs;
+
+  Module m("m");
+  m.define(R"(
+    def __init__(self, x):
+      self.training = True
+
+    def forward(self, input):
+      return torch.dropout(input, 1.0, self.training)
+  )");
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers,modernize-use-emplace)
+  inputs.push_back(torch::ones({1, 1, 28, 28}));
+  m.eval();
+  auto outputref = m.forward(inputs).toTensor();
+
+  // save m in training mode to make sure that mobile eval() will correctly
+  // change back to eval mode
+  m.train();
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  bc.eval();
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    res = bc.get_method("forward")(inputs);
+  }
+  auto output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(
+      outputref[0][0][0][0].item<int>() == output[0][0][0][0].item<int>());
+}
+
+TEST(LiteInterpreterDirectTest, FindWrongMethodName) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def add(self, x):
+      b = 4
+      return self.foo + x + b
+  )");
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  ASSERT_TRUE(bc.find_method("forward") == c10::nullopt);
+}
+
+TEST(LiteInterpreterDirectTest, FindAndRunMethod) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def add_it(self, x):
+      b = 4
+      return self.foo + x + b
+  )");
+
+  std::vector<IValue> inputs;
+  auto minput = 5 * torch::ones({});
+  inputs.emplace_back(minput);
+  auto ref = m.get_method("add_it")(inputs);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    auto bcinputs = inputs;
+    auto method = bc.find_method("add_it");
+    AT_ASSERT(method != c10::nullopt);
+    res = (*method)(std::move(bcinputs));
+  }
+
+  auto resd = res.toTensor().item<float>();
+  auto refd = ref.toTensor().item<float>();
+  AT_ASSERT(resd == refd);
+}
+
+TEST(LiteInterpreterDirectTest, RunMethodVariadic) {
+  Module m("m");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def add_three(self, x, y):
+      return self.foo + x + y
+  )");
+
+  std::vector<IValue> inputs;
+  auto inputx = 5 * torch::ones({});
+  auto inputy = 4 * torch::ones({});
+  auto ref = m.run_method("add_three", inputx, inputy);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res = bc.run_method("add_three", inputx, inputy);
+
+  auto resd = res.toTensor().item<float>();
+  auto refd = ref.toTensor().item<float>();
+  AT_ASSERT(resd == refd);
+}
+
+TEST(LiteInterpreterDirectTest, DuplicateSetState) {
+  Module m("M");
+  m.register_parameter("foo", torch::ones({}), false);
+  m.define(R"(
+    def __getstate__(self):
+      return self.foo + self.foo
+    def __setstate__(self, a):
+      self.foo = a
+    def forward(self, x):
+      b = 4
+      return self.foo + x + b
+  )");
+
+  Module b("B");
+  b.register_module("M0", m);
+  b.register_module("M1", m);
+  b.define(R"(
+    def forward(self, x):
+      return self.M0.forward(x) + self.M1.forward(x)
+  )");
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  const auto methods = bc.get_methods();
+  const size_t expected_n = 3;
+  ASSERT_EQ(methods.size(), expected_n);
+}
+
+TEST(LiteInterpreterDirectTest, OpNameExportFetchRootOperators) {
+  torch::jit::Module m("m");
+  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+  m.register_parameter("bias", torch::ones({20}), false);
+  m.define(R"(
+    def forward(self, input):
+      x1 = torch.zeros(2, 2)
+      x2 = torch.empty_like(torch.empty(2, 2))
+      x3 = torch._convolution(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, False, False, True, True)
+      return (x1, x2, x3)
+  )");
+  m.eval();
+
+  CompilationOptions options;
+  mobile::Module ptl_model = jitModuleToMobile(m, options);
+  std::set<std::string> operator_names =
+      torch::jit::mobile::_export_operator_list(ptl_model);
+  std::set<std::string> expected_operator_names = {
+      "aten::_convolution",
+      "aten::empty.memory_format",
+      "aten::empty_like",
+      "aten::zeros",
+  };
+  EXPECT_EQ(operator_names, expected_operator_names)
+      << "Expected the root operator lists to be the same";
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsConv) {
+  auto s = std::getenv("PYTORCH_TEST_WITH_TSAN");
+  if (s && strcmp(s, "1") == 0)
+    return;
+
+  std::vector<torch::jit::IValue> inputs;
+
+  Module m("m");
+  m.register_parameter("weight", torch::ones({20, 1, 5, 5}), false);
+  m.register_parameter("bias", torch::ones({20}), false);
+  m.define(R"(
+    def forward(self, input):
+      return torch.conv2d(input, self.weight, self.bias, [1, 1], [0, 0], [1, 1], 1)
+  )");
+
+  inputs.emplace_back(torch::ones({1, 1, 28, 28}));
+
+  auto outputref = m.forward(inputs).toTensor();
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 1; ++i) {
+    res = bc.get_method("forward")(inputs);
+  }
+  auto output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(output.equal(outputref));
+}
+
+namespace {
+void testLiteModuleCompareResultTensors(
+    Module& m,
+    const std::vector<torch::jit::IValue>& inputs,
+    const std::string& method_name = "forward") {
+  auto outputref = m.get_method(method_name)(inputs).toTensor();
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  IValue res;
+  for (int i = 0; i < 3; ++i) {
+    res = bc.get_method(method_name)(inputs);
+  }
+  auto output = res.toTensor();
+  AT_ASSERT(outputref.dim() == output.dim());
+  AT_ASSERT(output.equal(outputref));
+}
+
+void testDefaultArgsPinv2(int num_args) {
+  Module m("m");
+  if (num_args == 1) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input)
+    )");
+  } else if (num_args == 2) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5)
+    )");
+  } else if (num_args == 3) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, True)
+    )");
+  }
+
+  std::vector<torch::jit::IValue> inputs;
+  const int N = 28;
+  auto input = torch::range(1, N * N, 1);
+  input[0] = 1; // a more stable matrix
+  input = input.view({N, N});
+  inputs.emplace_back(input);
+  testLiteModuleCompareResultTensors(m, inputs);
+}
+} // namespace
+
+#if !defined FB_XPLAT_BUILD
+TEST(LiteInterpreterDirectTest, DefaultArgsPinv) {
+  // Test with different number of specified arguments.
+  // Arguments not specified take default value.
+  for (int num_args = 1; num_args <= 3; ++num_args) {
+    testDefaultArgsPinv2(num_args);
+  }
+
+  //  bytecode with one specified argument:
+  //  (6,
+  //      ('__torch__.m.forward',
+  //          (('instructions',
+  //              (('STOREN', 1, 2),
+  //                  ('DROPR', 1, 0),
+  //                  ('MOVE', 2, 0),
+  //                  ('OP', 0, 0),
+  //                  ('RET', 0, 0))),
+  //              ('operators', (('aten::linalg_pinv', '', 1),)),
+  //              ('constants', (False, 1e-15)), # default constants are not
+  //              used
+  //              ('types', ()),
+  //              ('register_size', 2)),
+  //          (('arguments',
+  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+  //              None)),
+  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
+  //                  None)))),
+  //              ('returns',
+  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
+  //                  None)),)))))
+
+  //  bytecode with 2 specified argument:
+  //  (6,
+  //      ('__torch__.m.forward',
+  //          (('instructions',
+  //              (('STOREN', 1, 2),
+  //                  ('DROPR', 1, 0),
+  //                  ('MOVE', 2, 0),
+  //                  ('LOADC', 1, 0), # added LOADC for specified argument
+  //                  ('OP', 0, 0),
+  //                  ('RET', 0, 0))),
+  //              ('operators', (('aten::linalg_pinv', '', 2),)),
+  //              ('constants', (False, 1e-05)), # updated constant table
+  //              ('types', ()),
+  //              ('register_size', 2)),
+  //          (('arguments',
+  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+  //              None)),
+  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
+  //                  None)))),
+  //              ('returns',
+  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
+  //                  None)),)))))
+
+  //  bytecode with 3 specified arguments:
+  //  (6,
+  //      ('__torch__.m.forward',
+  //          (('instructions',
+  //              (('STOREN', 1, 2),
+  //                  ('DROPR', 1, 0),
+  //                  ('MOVE', 2, 0),
+  //                  ('LOADC', 1, 0),
+  //                  ('LOADC', 0, 0),
+  //                  ('OP', 0, 0),
+  //                  ('RET', 0, 0))),
+  //              ('operators', (('aten::linalg_pinv', '', 3),)),
+  //              ('constants', (True, 1e-05)),
+  //              ('types', ()),
+  //              ('register_size', 2)),
+  //          (('arguments',
+  //              ((('name', 'self'), ('type', '__torch__.m'), ('default_value',
+  //              None)),
+  //                  (('name', 'input'), ('type', 'Tensor'), ('default_value',
+  //                  None)))),
+  //              ('returns',
+  //                  ((('name', ''), ('type', 'Tensor'), ('default_value',
+  //                  None)),)))))
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsTensorinvSpecifyDefault) {
+  // The second argument is specified, but the value is the same as the default
+  // value. It's treated as "not specified" since the value can be fetched from
+  // schema.
+  Module m("m");
+  m.define(R"(
+    def forward(self, input):
+      return torch.linalg_tensorinv(input, 2)
+  )");
+  torch::jit::MobileCode code(m.get_method("forward").graph(), "forward");
+  auto arg_nums = code.op_to_num_specified_args();
+  ASSERT_EQ(arg_nums.size(), 1);
+  ASSERT_EQ(arg_nums["aten::linalg_tensorinv"], 1);
+  std::vector<torch::jit::IValue> inputs;
+  const int N = 4;
+  auto input = torch::rand({N, N, N, N});
+  inputs.emplace_back(input);
+  testLiteModuleCompareResultTensors(m, inputs);
+}
+
+void testDefaultArgsPinvWithOutArg2(int num_args) {
+  Module m("m");
+  if (num_args == 1) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, out=input)
+    )");
+  } else if (num_args == 2) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, out=input)
+    )");
+  } else if (num_args == 3) {
+    m.define(R"(
+      def forward(self, input):
+        return torch.linalg_pinv(input, 1e-5, True, out=input)
+    )");
+  }
+
+  const int N = 28;
+  auto input = torch::range(1, N * N, 1);
+  input[0] = 10000; // a more stable matrix
+  input = input.view({N, N});
+  auto ref = m.run_method("forward", input);
+  TORCH_CHECK(!input.equal(torch::range(1, N * N, 1)));
+  TORCH_CHECK(input.equal(ref.toTensor()));
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsPinvWithOutArg) {
+  // Test with different number of specified arguments + out arg.
+  // Arguments not specified take default value.
+  for (int num_args = 1; num_args <= 3; ++num_args) {
+    testDefaultArgsPinvWithOutArg2(num_args);
+  }
+}
+
+TEST(LiteInterpreterDirectTest, DefaultArgsWithOutArg) {
+  Module m("m");
+  m.define(R"(
+    def forward(self, x, h):
+      torch.add(x, h, out=x)
+  )");
+
+  std::vector<IValue> inputs;
+  auto input_x = 2 * torch::ones({});
+  auto input_h = torch::ones({});
+  auto ref = m.run_method("forward", input_x, input_h);
+
+  CompilationOptions options;
+  mobile::Module bc = jitModuleToMobile(m, options);
+  bc.run_method("forward", input_x, input_h);
+  AT_ASSERT(input_x.equal(4 * torch::ones({})));
+}
+
+TEST(LiteInterpreterDirectTest, TestExceptionStackWithTwoLevelModuleHierarchy) {
+  Module a("A");
+  a.define(R"(
+    def bar(self, x, y):
+      return x + y
+  )");
+  Module b("B");
+  b.register_module("A0", a);
+  b.define(R"(
+    def foo(self, x, y):
+      return self.A0.bar(x, y) + 2
+  )");
+  Module c("C");
+  c.register_module("B0", b);
+  c.define(R"(
+    def forward(self, x, y):
+      return self.B0.foo(x, y) + 3
+  )");
+
+  std::vector<IValue> inputs;
+  inputs.emplace_back(torch::rand({2, 4}));
+  inputs.emplace_back(torch::rand({13, 9}));
+
+  CompilationOptions options;
+  auto lite_m = jitModuleToMobile(c, options);
+  std::string error_pattern = R"(
+  Module hierarchy:top(C)::<unknown>.B0(B)::foo.A0(A)::bar.aten::add
+Traceback of TorchScript (most recent call last):
+  File "<string>", line 3, in <unknown>
+
+    def forward(self, x, y):
+      return self.B0.foo(x, y) + 3
+             ~~~~~~~~~~~ <--- HERE
+
+  File "<string>", line 3, in foo
+
+    def foo(self, x, y):
+      return self.A0.bar(x, y) + 2
+             ~~~~~~~~~~~ <--- HERE
+
+  File "<string>", line 3, in bar
+
+    def bar(self, x, y):
+      return x + y
+             ~~~~~ <--- HERE
+  )";
+  ASSERT_THROWS_WITH_MESSAGE(lite_m.forward(inputs), error_pattern);
+}
+#endif // !defined(FB_XPLAT_BUILD)
+
+namespace {
+static auto reg =
+    torch::class_<TorchBindLiteInterpreterDirectTestStruct>(
+        "_TorchScriptTesting",
+        "_LiteInterpreterDirectTest")
+        .def(torch::init<>())
+        .def("get", &TorchBindLiteInterpreterDirectTestStruct::get)
+        .def_pickle(
+            // __getstate__
+            [](const c10::intrusive_ptr<
+                TorchBindLiteInterpreterDirectTestStruct>&) -> int64_t {
+              return 0;
+            },
+            // __setstate__
+            [](int64_t) {
+              return c10::make_intrusive<
+                  TorchBindLiteInterpreterDirectTestStruct>();
+            });
+
+} // namespace
+
+TEST(LiteInterpreterDirectTest, OperatorCacheDifferentiatesDefaultArgs) {
+  // Create 3 methods:
+  //
+  // 1. forward() returns a tensor with dtype=torch.int64 (4)
+  // 2. forward2() returns a tensor with dtype=torch.float32 (6)
+  // 3. forward3() returns a tensor with dtype=torch.float32 but
+  //    the dtype is inferred by the input tensor's dtype
+  //
+  // If caching works correctly, then the result from the full-jit
+  // module and the lite module will be the same. Otherwise, it
+  // will be different if we don't correctly ignore the cache
+  // entry for an operator that has a different number of
+  // arguments.
+  Module m("m");
+  m.define(R"(
+    def forward(self):
+      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=4)
+      return ret1.fill_(25)
+  )");
+  m.define(R"(
+    def forward2(self):
+      ret1 = torch.new_empty(torch.zeros(10), [10], dtype=6)
+      return ret1.fill_(32.0)
+  )");
+  m.define(R"(
+    def forward3(self):
+      ret1 = torch.new_empty(torch.zeros(10), [10])
+      return ret1.fill_(12.0)
+  )");
+
+  std::vector<torch::jit::IValue> inputs;
+  testLiteModuleCompareResultTensors(m, inputs, "forward");
+  testLiteModuleCompareResultTensors(m, inputs, "forward2");
+  testLiteModuleCompareResultTensors(m, inputs, "forward3");
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/api/function_impl.h b/torch/csrc/jit/api/function_impl.h
index 5fd20a9..27e61b8 100644
--- a/torch/csrc/jit/api/function_impl.h
+++ b/torch/csrc/jit/api/function_impl.h
@@ -24,6 +24,10 @@
 
   void run(Stack& stack) override;
 
+  std::function<void(GraphFunction&)> function_creator() const {
+    return function_creator_;
+  }
+
   c10::intrusive_ptr<c10::ivalue::Future> runAsync(
       Stack& stack,
       TaskLauncher taskLauncher = at::launch) override;
diff --git a/torch/csrc/jit/mobile/code.h b/torch/csrc/jit/mobile/code.h
index 5ea6617..c91a76b 100644
--- a/torch/csrc/jit/mobile/code.h
+++ b/torch/csrc/jit/mobile/code.h
@@ -20,6 +20,7 @@
   std::vector<Instruction> instructions_;
   std::vector<DebugHandle> debug_handles_;
   std::vector<c10::OperatorName> op_names_;
+  std::vector<int> operator_input_sizes_;
   std::vector<std::function<void(Stack&)>> operators_;
   std::vector<c10::IValue> constants_;
   std::vector<c10::TypePtr> types_;
diff --git a/torch/csrc/jit/mobile/debug_info.h b/torch/csrc/jit/mobile/debug_info.h
index 444573c..bbeb4ed 100644
--- a/torch/csrc/jit/mobile/debug_info.h
+++ b/torch/csrc/jit/mobile/debug_info.h
@@ -23,6 +23,10 @@
   MobileDebugTable(
       std::unique_ptr<caffe2::serialize::PyTorchStreamReader>& reader,
       const std::shared_ptr<CompilationUnit>& cu);
+
+  template <typename It>
+  MobileDebugTable(It begin, It end) : callstack_ptr_map_(begin, end) {}
+
   std::string getSourceDebugString(
       const int64_t debug_handle,
       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
@@ -36,6 +40,11 @@
       const std::vector<int64_t>& debug_handles,
       const std::string& top_module_type_name = "ModuleTypeUnknown") const;
 
+  const ska::flat_hash_map<int64_t, DebugInfoTuple>& getCallStackPtrMap()
+      const {
+    return callstack_ptr_map_;
+  }
+
  private:
   std::pair<std::string, std::string> getSourceDebugModuleHierarchyInfo(
       const std::vector<int64_t>& debug_handles,
diff --git a/torch/csrc/jit/mobile/function.cpp b/torch/csrc/jit/mobile/function.cpp
index bc1bd53..8af058e 100644
--- a/torch/csrc/jit/mobile/function.cpp
+++ b/torch/csrc/jit/mobile/function.cpp
@@ -13,6 +13,14 @@
 Function::Function(c10::QualifiedName name)
     : name_(std::move(name)), code_(std::make_shared<Code>()) {}
 
+Function::Function(
+    c10::QualifiedName name,
+    std::shared_ptr<Code> code,
+    at::optional<c10::FunctionSchema> schema)
+    : name_(std::move(name)),
+      code_(std::move(code)),
+      schema_(std::move(schema)) {}
+
 const c10::QualifiedName& Function::qualname() const {
   return name_;
 }
@@ -43,89 +51,11 @@
   // Keep the original opname in code_
   code_->op_names_.emplace_back(name, overload_name);
   const auto& opname = code_->op_names_.back();
-  const auto full_name = c10::toString(opname);
-
-  std::function<void(Stack&)> fn;
-
-  const std::vector<c10::Argument>* pArgs = nullptr;
-  bool promoted_op = mobile::hasPrimOpsFn(full_name);
-  if (promoted_op) {
-    fn = mobile::getPrimOpsFn(full_name);
-  } else {
-    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
-    if (jit_op) {
-      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
-      pArgs = &jit_op->schema().arguments();
-    } else {
-      auto op = c10::Dispatcher::singleton().findSchema(opname);
-      if (op.has_value()) {
-        fn = [op](Stack& stack) { op->callBoxed(&stack); };
-        if (op->hasSchema()) {
-          pArgs = &op->schema().arguments();
-        } else {
-          TORCH_CHECK(false, "arguments are missing for operator ", opname);
-        }
-      } else {
-        return false;
-      }
-    }
+  auto func = makeOperatorFunction(opname, num_specified_args, model_version);
+  if (!func.has_value()) {
+    return false;
   }
-
-  if (!promoted_op) {
-    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
-    const auto& args = *pArgs;
-    if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
-        opname.overload_name.empty()) {
-      // Since byte-code versions 0x4L, convolution has an additional
-      // default-value argument (allow_tf32=True, see
-      // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
-      // backward compatibility with models of byte-code version <= 0x3L, where
-      // this bool argument does not yet exist.
-      fn = [fn](Stack& stack) {
-        stack.push_back(true);
-        fn(stack);
-      };
-    } else {
-      // num_specified_args >= 0 indicates number of arguments are available
-      // from model. We can use it to handle backward compatibility.
-      if (num_specified_args &&
-          num_specified_args.value() < static_cast<int64_t>(args.size())) {
-        fn = [fn, num_specified_args, &args](Stack& stack) {
-          std::vector<IValue> out_args;
-          // The following logic pops and temporarily stores all out arguments
-          // from the stack (which can be 0 or more, and always appended to the
-          // schema), in order to push the necessary default values. Finally,
-          // the out arguments are pushed back into the stack.
-          for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
-            out_args.push_back(stack.back());
-            stack.pop_back();
-          }
-          size_t start_index = num_specified_args.value() - out_args.size();
-          TORCH_CHECK(
-              start_index >= 0,
-              "The number of output arguments is: ",
-              out_args.size(),
-              ", which is more then the number of specified arguments: ",
-              num_specified_args.value());
-          for (size_t i = start_index; i < (args.size() - out_args.size());
-               ++i) {
-            TORCH_CHECK(
-                args[i].default_value().has_value(),
-                "Error happened at preparing for default values for the argument. The ",
-                i,
-                "th argument ",
-                args[i].name(),
-                " does not have a specified value or default value. ");
-
-            stack.push_back(args[i].default_value());
-          }
-          stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
-          fn(stack);
-        };
-      }
-    }
-  }
-  code_->operators_.emplace_back(fn);
+  code_->operators_.emplace_back(*func);
   return true;
 }
 
@@ -197,6 +127,93 @@
   return getInterpretersExceptionDebugHandles();
 }
 
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+    c10::OperatorName opname,
+    c10::optional<int> num_specified_args,
+    int64_t model_version) {
+  std::function<void(Stack&)> fn;
+  const auto full_name = c10::toString(opname);
+  const std::vector<c10::Argument>* pArgs = nullptr;
+  bool promoted_op = mobile::hasPrimOpsFn(full_name);
+  if (promoted_op) {
+    fn = mobile::getPrimOpsFn(full_name);
+  } else {
+    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
+    if (jit_op) {
+      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
+      pArgs = &jit_op->schema().arguments();
+    } else {
+      auto op = c10::Dispatcher::singleton().findSchema(opname);
+      if (op.has_value()) {
+        fn = [op](Stack& stack) { op->callBoxed(&stack); };
+        if (op->hasSchema()) {
+          pArgs = &op->schema().arguments();
+        } else {
+          TORCH_CHECK(false, "arguments are missing for operator ", opname);
+        }
+      } else {
+        return c10::nullopt;
+      }
+    }
+  }
+
+  if (!promoted_op) {
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
+    const auto& args = *pArgs;
+    if (model_version == 0x3LL && opname.name == "aten::_convolution" &&
+        opname.overload_name.empty()) {
+      // Since byte-code versions 0x4L, convolution has an additional
+      // default-value argument (allow_tf32=True, see
+      // https://github.com/pytorch/pytorch/pull/40737). This wrapper handles
+      // backward compatibility with models of byte-code version <= 0x3L, where
+      // this bool argument does not yet exist.
+      fn = [fn](Stack& stack) {
+        stack.push_back(true);
+        fn(stack);
+      };
+    } else {
+      // num_specified_args >= 0 indicates number of arguments are available
+      // from model. We can use it to handle backward compatibility.
+      if (num_specified_args &&
+          num_specified_args.value() < static_cast<int64_t>(args.size())) {
+        fn = [fn, num_specified_args, &args](Stack& stack) {
+          std::vector<IValue> out_args;
+          // The following logic pops and temporarily stores all out arguments
+          // from the stack (which can be 0 or more, and always appended to the
+          // schema), in order to push the necessary default values. Finally,
+          // the out arguments are pushed back into the stack.
+          for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
+            out_args.push_back(stack.back());
+            stack.pop_back();
+          }
+          size_t start_index = num_specified_args.value() - out_args.size();
+          TORCH_CHECK(
+              start_index >= 0,
+              "The number of output arguments is: ",
+              out_args.size(),
+              ", which is more then the number of specified arguments: ",
+              num_specified_args.value());
+          for (size_t i = start_index; i < (args.size() - out_args.size());
+               ++i) {
+            TORCH_CHECK(
+                args[i].default_value().has_value(),
+                "Error happened at preparing for default values for the argument. The ",
+                i,
+                "th argument ",
+                args[i].name(),
+                " does not have a specified value or default value. ");
+
+            stack.push_back(args[i].default_value());
+          }
+          stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
+          fn(stack);
+        };
+      }
+    }
+  }
+  return fn;
+}
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/function.h b/torch/csrc/jit/mobile/function.h
index 5e20515..c29638a 100644
--- a/torch/csrc/jit/mobile/function.h
+++ b/torch/csrc/jit/mobile/function.h
@@ -17,6 +17,10 @@
 class TORCH_API Function : public torch::jit::Function {
  public:
   explicit Function(c10::QualifiedName name);
+  Function(
+      c10::QualifiedName name,
+      std::shared_ptr<Code> code,
+      at::optional<c10::FunctionSchema> schema);
   void run(Stack& stack) override;
   at::IValue operator()(Stack& stack);
   void ensure_defined() override {}
@@ -24,6 +28,9 @@
   const c10::QualifiedName& qualname() const override;
   bool call(Stack&, c10::function_ref<void(const mobile::Code&)>) override;
 
+  // NOTE: the APIs below are dangerous: if you call append_instruction with a
+  // dbg_handle and then call it without one, the dbg_handle will become
+  // misaligned. Therefore only use ONE variant at a time.
   void append_instruction(OpCode op, int X, int N, int64_t dbg_handle);
   void append_instruction(OpCode op, int X, int N);
   bool append_operator(
@@ -56,6 +63,11 @@
   at::optional<c10::FunctionSchema> schema_; // (byte-code version 4+)
 };
 
+c10::optional<std::function<void(Stack&)>> makeOperatorFunction(
+    c10::OperatorName opname,
+    c10::optional<int> num_specified_args,
+    int64_t model_version);
+
 } // namespace mobile
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/mobile/interpreter.cpp b/torch/csrc/jit/mobile/interpreter.cpp
index 99e60c4..fddd3aa 100644
--- a/torch/csrc/jit/mobile/interpreter.cpp
+++ b/torch/csrc/jit/mobile/interpreter.cpp
@@ -94,15 +94,15 @@
         debug_handle = *handle;
       }
 
-      // std::cout << "RUNNING " << pc << " "
-      //           << code_->instructions_with_handles_[pc].instruction;
+      // std::cout << "RUNNING " << pc << " " << code.instructions_[pc];
       // if (inst.op == OP) {
-      //   std::cout << ", " << code_->op_names_[inst.X].name;
-      //   if (!code_->op_names_[inst.X].overload_name.empty()) {
-      //     std::cout << "." << code_->op_names_[inst.X].overload_name;
+      //   std::cout << ", " << code.op_names_[inst.X].name;
+      //   if (!code.op_names_[inst.X].overload_name.empty()) {
+      //     std::cout << "." << code.op_names_[inst.X].overload_name;
       //   }
       // }
       // std::cout << std::endl;
+      // std::cout << "top " << stack.back().tagKind() << std::endl;
 
       // TODO(iliacher): remove the workaround after RecordFunction is in
       // Dispatcher
diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h
index 4d00087..31d3f58 100644
--- a/torch/csrc/jit/mobile/module.h
+++ b/torch/csrc/jit/mobile/module.h
@@ -135,7 +135,7 @@
   std::unordered_map<std::string, std::string> metadata_;
   std::shared_ptr<CompilationUnit> cu_;
   MobileDebugTable debug_table_;
-  bool has_debug_handles_;
+  bool has_debug_handles_ = false;
 };
 } // namespace mobile
 } // namespace jit
diff --git a/torch/csrc/jit/runtime/interpreter.cpp b/torch/csrc/jit/runtime/interpreter.cpp
index d102a50..32e567e 100644
--- a/torch/csrc/jit/runtime/interpreter.cpp
+++ b/torch/csrc/jit/runtime/interpreter.cpp
@@ -33,7 +33,6 @@
 #endif
 
 #include <exception>
-#include <iostream>
 #include <memory>
 #include <mutex>
 #include <ostream>
diff --git a/torch/csrc/jit/serialization/export_bytecode.cpp b/torch/csrc/jit/serialization/export_bytecode.cpp
index 2e37fa7..a1348dd 100644
--- a/torch/csrc/jit/serialization/export_bytecode.cpp
+++ b/torch/csrc/jit/serialization/export_bytecode.cpp
@@ -1,23 +1,333 @@
 #include <torch/csrc/jit/serialization/export_bytecode.h>
+#include <utility>
 
 #include <torch/csrc/jit/runtime/instruction.h>
 #include <torch/csrc/jit/serialization/export.h>
 
+#include <c10/util/Exception.h>
+#include <torch/csrc/jit/api/function_impl.h>
+#include <torch/csrc/jit/api/method.h>
+#include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/backends/backend_debug_info.h>
+#include <torch/csrc/jit/frontend/source_range.h>
+#include <torch/csrc/jit/ir/attributes.h>
+#include <torch/csrc/jit/ir/ir.h>
+#include <torch/csrc/jit/ir/type_hashing.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/interpreter.h>
+#include <torch/csrc/jit/mobile/method.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/passes/inliner.h>
+#include <torch/csrc/jit/serialization/callstack_debug_info_serialization.h>
+#include <torch/csrc/jit/serialization/import_export_constants.h>
+#include <torch/csrc/jit/serialization/import_export_functions.h>
+#include <torch/csrc/jit/serialization/import_export_helpers.h>
+#include <torch/csrc/jit/serialization/pickle.h>
+#include <torch/csrc/jit/serialization/python_print.h>
+#include <torch/csrc/jit/serialization/source_range_serialization.h>
+#include <torch/csrc/jit/serialization/type_name_uniquer.h>
+
+#include <caffe2/serialize/inline_container.h>
+
 namespace torch {
 namespace jit {
 
-void BytecodeExportSet::add(
-    const c10::QualifiedName& qn,
-    ExportedFunction exported) {
-  items_.emplace(qn, std::move(exported));
+std::vector<Method> gatherGetSetStates(ObjectPtr obj) {
+  std::vector<Method> methods;
+  // Use DFS on IValues to traverse the dependencies of module._ivalue and
+  // add all __setstate__/__getstate__ methods to the initial stack.
+  std::vector<ObjectPtr> ivalue_stack;
+  ivalue_stack.emplace_back(obj);
+  while (!ivalue_stack.empty()) {
+    ObjectPtr cur = ivalue_stack.back();
+    ivalue_stack.pop_back();
+    auto type = cur->type();
+    Function* setstate = type->findMethod("__setstate__");
+    Function* getstate = type->findMethod("__getstate__");
+    if (getstate && setstate) {
+      if (setstate->isGraphFunction()) {
+        methods.emplace_back(cur, setstate);
+      }
+      if (getstate->isGraphFunction()) {
+        methods.emplace_back(cur, getstate);
+      }
+    } else {
+      for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
+        IValue field = cur->getSlot(i);
+        if (field.isObject()) {
+          ivalue_stack.emplace_back(field.toObject());
+        }
+      }
+    }
+  }
+  return methods;
 }
 
-void BytecodeExportSet::update(const c10::QualifiedName& qn, bool toplevel) {
-  items_.at(qn).toplevel = toplevel;
+std::vector<Method> findAllDependentFunctions(
+    const Module& module,
+    Graph& graph) {
+  std::vector<Method> methods;
+  std::unordered_set<c10::string_view> called_method_names;
+  auto nodes = findAllNodes(graph, c10::prim::CallMethod, true);
+  for (Node* node : nodes) {
+    if (auto iface = node->input(0)->type()->castRaw<InterfaceType>()) {
+      const FunctionSchema* schema = iface->getMethod(node->s(attr::name));
+      called_method_names.insert(schema->name());
+    }
+  }
+
+  for (const auto& submodule : module.modules()) {
+    for (const auto& m : submodule.get_methods()) {
+      if (called_method_names.find(m.function().qualname().name()) !=
+          called_method_names.end()) {
+        methods.emplace_back(m);
+      }
+    }
+  }
+  return methods;
 }
 
-bool BytecodeExportSet::contains(const c10::QualifiedName& qn) const {
-  return items_.find(qn) != items_.end();
+// NOTE: the order of the functions returned will be:
+// 1. functions originating from the methods passed in come first,
+// 2. all of their dependent functions come afterwards.
+// This order is meaningful because currently mobile Module looks up
+// methods with linear search.
+std::vector<std::unique_ptr<GraphFunction>> inlineFunctions(
+    const std::vector<Method>& initial_methods,
+    bool incl_dependent_functions) {
+  std::set<std::pair<std::string, Function*>> visited;
+  std::deque<Method> stack;
+  std::copy(
+      initial_methods.begin(),
+      initial_methods.end(),
+      std::back_inserter(stack));
+  std::vector<std::unique_ptr<GraphFunction>> inlined_functions;
+  while (!stack.empty()) {
+    Method cur = stack.front();
+    stack.pop_front();
+    auto tup = std::make_pair(
+        cur.owner()._ivalue()->type()->name()->qualifiedName(),
+        &cur.function());
+    if (visited.find(tup) != visited.end()) {
+      continue;
+    }
+    visited.insert(tup);
+    const auto& f = toGraphFunction(cur.function());
+    auto graph = f.graph()->copyUnique();
+    Inline(*graph);
+    c10::QualifiedName qn(*cur.owner()._ivalue()->type()->name(), f.name());
+
+    if (incl_dependent_functions) {
+      std::vector<Method> dependent_methods =
+          findAllDependentFunctions(cur.owner(), *graph);
+      std::copy(
+          dependent_methods.begin(),
+          dependent_methods.end(),
+          std::back_inserter(stack));
+    }
+    auto inlined_func = std::make_unique<GraphFunction>(
+        qn, std::move(graph), f.function_creator());
+    inlined_func->setSchema(f.getSchema());
+    inlined_functions.emplace_back(std::move(inlined_func));
+  }
+  return inlined_functions;
+}
+
+std::unique_ptr<mobile::Code> compileGraphToMobileCode(
+    const std::string& name,
+    const std::shared_ptr<Graph>& graph,
+    const CompilationOptions& compilation_options,
+    BackendDebugInfoRecorder& debug_info_recorder) {
+  MobileCode code(
+      graph,
+      name,
+      compilation_options.enable_default_value_for_unspecified_arg,
+      compilation_options.enable_default_args_before_out_args);
+
+  std::unique_ptr<mobile::Code> mobile_code_ptr =
+      std::make_unique<mobile::Code>();
+  mobile::Code& mobile_code = *mobile_code_ptr;
+
+  // operator names
+  std::vector<std::string> method_names;
+  std::vector<int64_t> op_debug_handles;
+  int next_new_op_index = 0;
+
+  auto op_to_specified_args = code.op_to_num_specified_args();
+
+  for (size_t i = 0; i < code.instructions().size(); ++i) {
+    Instruction ins = code.instructions()[i];
+
+    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
+      // Found a new op (assumes new operators ordered by ascending ins.X)
+      auto node = code.instructions_source()[i];
+      const c10::OperatorName& opname = node->schema().operator_name();
+      auto unique_name = c10::toString(opname);
+      // For an operator with varargs, adding default arguments would be
+      // confusing and is not allowed. For an operator with num_args = -1, the
+      // number of arguments is not available, so we don't do any backward
+      // compatibility adaptation at runtime.
+      c10::optional<int> num_args = c10::nullopt;
+      auto it = op_to_specified_args.find(unique_name);
+      if (it != op_to_specified_args.end()) {
+        num_args = it->second;
+      }
+      mobile_code.operator_input_sizes_.emplace_back(num_args.value_or(-1));
+      mobile_code.op_names_.emplace_back(opname);
+      auto func = mobile::makeOperatorFunction(
+          opname, num_args, compilation_options.model_version);
+      TORCH_INTERNAL_ASSERT(
+          func.has_value(),
+          "Operator with name: ",
+          toString(opname),
+          " not found");
+      mobile_code.operators_.emplace_back(*func);
+      next_new_op_index++;
+    }
+    // CALL nodes at this point represent built-in (i.e. non-Graph)
+    // functions that were not inlined. Here we convert the CALL
+    // instructions for these functions into INTERFACE_CALL instructions
+    // s.t. at runtime, we will look up the Function* on the Type of the
+    // 0th argument in the stack and call that directly.
+    if (ins.op == CALL) {
+      auto node = code.instructions_source()[i];
+      if (node->kind() == prim::CallMethod) {
+        // NB: replacing instruction
+        auto method_name_idx =
+            code.constant_table().size() + method_names.size();
+        method_names.emplace_back(node->s(attr::name));
+        ins = Instruction{
+            INTERFACE_CALL,
+            static_cast<int32_t>(method_name_idx),
+            static_cast<uint16_t>(node->inputs().size())};
+      } else {
+        TORCH_INTERNAL_ASSERT(
+            false, "Unsupported node kind on CALL opcode for mobile");
+      }
+    } else if (ins.op == RET) {
+      auto node = code.instructions_source()[i];
+      for (const auto& input : node->inputs()) {
+        const auto& input_type = input->type();
+        if (input_type->kind() == TypeKind::ListType ||
+            input_type->kind() == TypeKind::DictType) {
+          for (const TypePtr& element_type : input_type->containedTypes()) {
+            TORCH_CHECK(
+                element_type->kind() != TypeKind::ClassType,
+                "Returining a list or dictionary with pytorch class type ",
+                "is not supported in mobile module "
+                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
+                "Workaround: instead of using pytorch class as their element type, ",
+                "use a combination of list, dictionary, and single types.");
+          }
+        }
+      }
+    } else {
+      TORCH_CHECK(
+          isOpSupportedInMobile(ins.op),
+          toString(ins.op),
+          " is not supported in mobile module.");
+    }
+    auto node = code.instructions_source()[i];
+    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
+    // Note 1-to-1 correspondence between instructions and debug handles
+    mobile_code.instructions_.emplace_back(ins);
+    mobile_code.debug_handles_.emplace_back(debug_handle);
+  }
+
+  // copy constants
+  mobile_code.constants_ = code.constant_table();
+
+  // Make a copy of the constants and append the method names
+  // that we emitted for the converted INTERFACE_CALL nodes above.
+  for (auto& method_name : method_names) {
+    mobile_code.constants_.emplace_back(method_name);
+  }
+
+  mobile_code.types_ = code.type_table();
+  mobile_code.register_size_ = code.register_size();
+  return mobile_code_ptr;
+}
+
+void checkSchema(const FunctionSchema& schema) {
+  TORCH_CHECK(
+      schema.overload_name().empty(), // @TODO: is this check correct?
+      "Overloads are not supported in mobile modules.");
+  TORCH_CHECK(
+      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
+  TORCH_CHECK(
+      !schema.is_varret(),
+      "A variable number of return values is not supported in mobile modules.");
+}
+
+bool isLoweredModule(const Module& m) {
+  c10::QualifiedName type_name;
+  if (m.type()->name()) {
+    type_name = m.type()->name().value();
+  }
+  bool isLoweredModule = false;
+  for (const auto& atom : type_name.atoms()) {
+    if (atom == "LoweredModule") {
+      isLoweredModule = true;
+      break;
+    }
+  }
+  return isLoweredModule;
+}
+
+// Check if the global static map of backend debug info
+// contains debug info for this module and any of its children.
+// If so combine all the maps together and return one.
+void getBackendDebugInfoMap(
+    const Module& m,
+    BackendDebugInfoMapType& debug_map) {
+  if (isLoweredModule(m)) {
+    auto backend_debug_info =
+        m.attr("__backend_debug_info").toCustomClass<PyTorchBackendDebugInfo>();
+    const auto& map = backend_debug_info->getDebugInfoMap();
+    if (map) {
+      debug_map.insert(map.value().begin(), map.value().end());
+    }
+  }
+  for (const auto& c : m.children()) {
+    getBackendDebugInfoMap(c, debug_map);
+  }
+}
+
+mobile::Module jitModuleToMobile(
+    const Module& module,
+    const CompilationOptions& options) {
+  std::shared_ptr<mobile::CompilationUnit> mcu =
+      std::make_shared<mobile::CompilationUnit>();
+  BackendDebugInfoRecorder debug_info_recorder;
+
+  std::vector<Method> methods_to_export = module.get_methods();
+  std::vector<Method> getsetstates = gatherGetSetStates(module._ivalue());
+  std::copy(
+      getsetstates.begin(),
+      getsetstates.end(),
+      std::back_inserter(methods_to_export));
+
+  for (const auto& func :
+       inlineFunctions(methods_to_export, options.incl_interface_call)) {
+    std::shared_ptr<mobile::Code> mobile_code_ptr = compileGraphToMobileCode(
+        func->name(), func->graph(), options, debug_info_recorder);
+    const auto& schema = func->getSchema();
+    checkSchema(schema);
+    auto mobile_func = std::make_unique<mobile::Function>(
+        func->qualname(), mobile_code_ptr, schema);
+    mcu->register_function(std::move(mobile_func));
+  }
+
+  mobile::Module m(module._ivalue(), mcu);
+  m.setHasDebugHandles(true);
+  BackendDebugInfoMapType backend_debug_info_map;
+  getBackendDebugInfoMap(module, backend_debug_info_map);
+  auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
+  debug_handle_cs_ptr_map.insert(
+      backend_debug_info_map.begin(), backend_debug_info_map.end());
+  m.setDebugTable(MobileDebugTable(
+      debug_handle_cs_ptr_map.begin(), debug_handle_cs_ptr_map.end()));
+  return m;
 }
 
 } // namespace jit
diff --git a/torch/csrc/jit/serialization/export_bytecode.h b/torch/csrc/jit/serialization/export_bytecode.h
index f97434b..8782b5d 100644
--- a/torch/csrc/jit/serialization/export_bytecode.h
+++ b/torch/csrc/jit/serialization/export_bytecode.h
@@ -1,59 +1,31 @@
 #pragma once
 
+#include <tuple>
 #include <unordered_map>
 
 #include <ATen/core/function_schema.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <ATen/core/qualified_name.h>
 #include <torch/csrc/jit/backends/backend_debug_handler.h>
+#include <torch/csrc/jit/mobile/function.h>
+#include <torch/csrc/jit/mobile/module.h>
 #include <torch/csrc/jit/runtime/interpreter.h>
 #include <torch/csrc/jit/serialization/type_name_uniquer.h>
 
 namespace torch {
 namespace jit {
 
-struct ExportedFunction {
-  ExportedFunction(
-      const Module& m,
-      const Function& f,
-      std::unique_ptr<Graph> g,
-      bool t)
-      : mod(m), function(f), optimizedGraph(std::move(g)), toplevel(t) {}
-  Module mod;
-  const Function& function;
-  std::unique_ptr<Graph> optimizedGraph;
-  bool toplevel;
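+// Options controlling how a jit::Module is compiled to mobile bytecode
+// (interface-call export, default-argument handling, and the bytecode
+// version to target).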
+struct TORCH_API CompilationOptions {
+  bool incl_interface_call = false;
+  bool enable_default_value_for_unspecified_arg = false;
+  bool enable_default_args_before_out_args = true;
+  int model_version = caffe2::serialize::kProducedBytecodeVersion;
 };
 
-class TORCH_API BytecodeExportSet {
- public:
-  BytecodeExportSet() = default;
-  BytecodeExportSet(const BytecodeExportSet&) = delete;
-  BytecodeExportSet& operator=(const BytecodeExportSet&) = delete;
-  BytecodeExportSet(BytecodeExportSet&&) = default;
-  BytecodeExportSet& operator=(BytecodeExportSet&&) = default;
-
-  void add(const c10::QualifiedName& qn, ExportedFunction);
-  void update(const c10::QualifiedName& qn, bool toplevel);
-  bool contains(const c10::QualifiedName& qn) const;
-
-  template <typename F>
-  void visit(F&& f) {
-    for (auto& item : items_) {
-      if (item.second.toplevel) {
-        f(item.first, item.second);
-      }
-    }
-    for (auto& item : items_) {
-      if (!item.second.toplevel) {
-        f(item.first, item.second);
-      }
-    }
-  }
-
- private:
-  std::unordered_map<c10::QualifiedName, ExportedFunction> items_;
-};
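+// Compiles a jit::Module into a mobile::Module in memory, without
+// saving to and reloading from the bytecode format.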
+TORCH_API mobile::Module jitModuleToMobile(
+    const Module& module,
+    const CompilationOptions& options);
 
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/serialization/export_module.cpp b/torch/csrc/jit/serialization/export_module.cpp
index 61d39fa..a13193b 100644
--- a/torch/csrc/jit/serialization/export_module.cpp
+++ b/torch/csrc/jit/serialization/export_module.cpp
@@ -38,6 +38,18 @@
 namespace torch {
 namespace jit {
 
+CompilationOptions getOptionsFromGlobal() {
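+  // Snapshot the current global bytecode-emit flags so the direct
+  // jit::Module -> mobile::Module conversion uses the same settings as the
+  // existing export path.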
+  CompilationOptions compilation_options;
+  compilation_options.enable_default_args_before_out_args =
+      BytecodeEmitMode::is_default_args_before_out_args_enabled();
+  compilation_options.enable_default_value_for_unspecified_arg =
+      BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled();
+  compilation_options.incl_interface_call = getMobileInterfaceCallExport();
+  compilation_options.model_version =
+      caffe2::serialize::kProducedBytecodeVersion;
+  return compilation_options;
+}
+
 IValue to_tuple(std::initializer_list<IValue> ivalues) {
   return c10::ivalue::Tuple::create(ivalues);
 }
@@ -63,138 +75,49 @@
 }
 
 std::pair<IValue, IValue> getFunctionTuple(
-    const Module& module,
-    const Function& func,
-    std::unique_ptr<Graph> optimizedGraph,
+    const CompilationUnit& compilation_unit,
+    const mobile::Function& func,
     BackendDebugInfoRecorder& debug_info_recorder,
-    const std::string& qn,
     TypeNameUniquer& type_name_uniquer_) {
-  TORCH_INTERNAL_ASSERT(optimizedGraph);
-  std::shared_ptr<MobileCode> code;
-  code = std::make_shared<MobileCode>(
-      std::move(optimizedGraph), func.name(), BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled() /* emit_default_input_instructions */, BytecodeEmitMode::is_default_args_before_out_args_enabled() /* enable_defaults_args_with_out_args */);
-  auto instructions_copy = code->instructions();
-
-  // operator names
-  std::vector<c10::OperatorName> opnames;
-  std::vector<std::string> method_names;
-  std::vector<int64_t> op_debug_handles;
-  int next_new_op_index = 0;
-  for (size_t i = 0; i < instructions_copy.size(); ++i) {
-    Instruction ins = instructions_copy[i];
-    if ((ins.op == OP || ins.op == OPN) && ins.X == next_new_op_index) {
-      // Found a new op (assumes new operators ordered by ascending ins.X)
-      auto node = code->instructions_source()[i];
-      opnames.emplace_back(node->schema().operator_name());
-      next_new_op_index++;
-    }
-    // CALL nodes at this point represent built-in (i.e. non-Graph)
-    // functions that were not inlined. Here we convert the CALL
-    // instructions for these functions into INTERFACE_CALL instructions
-    // s.t. at runtime, we will look up the Function* on the Type of the
-    // 0th argument in the stack and call that directly.
-    if (ins.op == CALL) {
-      auto node = code->instructions_source()[i];
-      if (node->kind() == prim::CallMethod) {
-        // NB: replacing instruction
-        auto method_name_idx =
-            code->constant_table().size() + method_names.size();
-        method_names.emplace_back(node->s(attr::name));
-        Instruction new_instr{
-            INTERFACE_CALL,
-            static_cast<int32_t>(method_name_idx),
-            static_cast<uint16_t>(node->inputs().size())};
-        instructions_copy[i] = new_instr;
-      } else {
-        TORCH_INTERNAL_ASSERT(
-            false, "Unsupported node kind on CALL opcode for mobile");
-      }
-    } else if (ins.op == RET) {
-      auto node = code->instructions_source()[i];
-      for (const auto& input : node->inputs()) {
-        const auto& input_type = input->type();
-        if (input_type->kind() == TypeKind::ListType ||
-            input_type->kind() == TypeKind::DictType) {
-          for (const TypePtr& element_type : input_type->containedTypes()) {
-            TORCH_CHECK(
-                element_type->kind() != TypeKind::ClassType,
-                "Returining a list or dictionary with pytorch class type ",
-                "is not supported in mobile module "
-                "(List[Foo] or Dict[int, Foo] for class Foo(torch.nn.Module)). "
-                "Workaround: instead of using pytorch class as their element type, ",
-                "use a combination of list, dictionary, and single types.");
-          }
-        }
-      }
-    } else {
-      TORCH_CHECK(
-          isOpSupportedInMobile(ins.op),
-          toString(ins.op),
-          " is not supported in mobile module.");
-    }
-    auto node = code->instructions_source()[i];
-    int64_t debug_handle = debug_info_recorder.getNextDebugHandle(node);
-    // Note 1-to-1 correspondence between instructions and debug handles
-    op_debug_handles.emplace_back(debug_handle);
-  }
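+  // The function has already been compiled to mobile::Code by
+  // jitModuleToMobile; serialization reads instructions, operators,
+  // constants, types, and debug handles directly from it.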
+  const std::shared_ptr<mobile::Code> mobile_code_ptr = func.get_code();
 
   // instructions
   std::vector<IValue> instructions;
-  instructions.reserve(instructions_copy.size());
-  for (Instruction ins : instructions_copy) {
+  instructions.reserve(mobile_code_ptr->instructions_.size());
+  for (Instruction ins : mobile_code_ptr->instructions_) {
     instructions.emplace_back(to_tuple({toString(ins.op), ins.X, ins.N}));
   }
 
   // operators
   std::vector<IValue> operators;
-  auto op_to_specified_args = code->op_to_num_specified_args();
-  operators.reserve(opnames.size());
-  for (const auto& opname : opnames) {
-    auto unique_name = c10::toString(opname);
-    // For operator with vararg, adding default arguments would be confusing and
-    // is not allowed. For an operator with num_args = -1, it means the number
-    // of arguments is not available for this operator, we don't do any backward
-    // compatibility adaptation at runtime.
-    int num_args = -1;
-    auto it = op_to_specified_args.find(unique_name);
-    if (it != op_to_specified_args.end()) {
-      num_args = it->second;
-    }
+  operators.reserve(mobile_code_ptr->op_names_.size());
+  for (int i = 0; i < mobile_code_ptr->op_names_.size(); ++i) {
+    const auto& opname = mobile_code_ptr->op_names_[i];
+    const int size = mobile_code_ptr->operator_input_sizes_[i];
     if (BytecodeEmitMode::is_default_value_for_unspecified_arg_enabled()) {
       operators.emplace_back(to_tuple({opname.name, opname.overload_name}));
     } else {
       operators.emplace_back(
-          to_tuple({opname.name, opname.overload_name, num_args}));
+          to_tuple({opname.name, opname.overload_name, size}));
     }
   }
 
-  // constants
-  //
-  // Make a copy of the constants and append the method names
-  // that we emitted for the converted INTERFACE_CALL nodes above.
-  auto constants = code->constant_table();
-  for (auto& method_name : method_names) {
-    constants.emplace_back(std::move(method_name));
-  }
-
   // types
   std::vector<IValue> types;
-  types.reserve(code->type_table().size());
+  types.reserve(mobile_code_ptr->types_.size());
   static const std::string torch_prefix("__torch__");
   static const std::string class_prefix("__torch__.torch.classes");
-  std::shared_ptr<torch::jit::CompilationUnit> cu =
-      module._ivalue()->compilation_unit();
 
-  for (const TypePtr& t : code->type_table()) {
+  for (const TypePtr& t : mobile_code_ptr->types_) {
     std::string type_str = t->annotation_str();
     if (t->kind() == TypeKind::TupleType) {
       TORCH_CHECK(
-          cu->get_named_tuple(t->str()),
+          compilation_unit.get_named_tuple(t->str()),
           "Can't find definition for the qualified name: ",
           t->str(),
           "(TypeKind::TupleType)  in compilation unit.",
           "Please report a bug to PyTorch.");
-      auto named_tuple_type = cu->get_named_tuple(t->str());
+      auto named_tuple_type = compilation_unit.get_named_tuple(t->str());
       if (named_tuple_type != nullptr) {
         std::string named_tuple_str = t->str();
         named_tuple_str.append("[NamedTuple, [");
@@ -254,12 +177,12 @@
 
   // since the register location is embedded into the bytecode, pass the
   // register size
-  auto register_size = static_cast<int>(code->register_size());
+  auto register_size = static_cast<int>(mobile_code_ptr->register_size_);
 
   auto codeTable = Table(
       {{"instructions", to_tuple(instructions)},
        {"operators", to_tuple(operators)},
-       {"constants", to_tuple(constants)},
+       {"constants", to_tuple(mobile_code_ptr->constants_)},
        {"types", to_tuple(types)},
        {"register_size", register_size}});
 
@@ -273,14 +196,7 @@
     }
     return c10::nullopt;
   };
-  TORCH_CHECK(
-      schema.overload_name().empty(), // @TODO: is this check correct?
-      "Overloads are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_vararg(), "Python *args are not supported in mobile modules.");
-  TORCH_CHECK(
-      !schema.is_varret(),
-      "A variable number of return values is not supported in mobile modules.");
+
   auto makeArgTuple = [&](const std::vector<Argument>& args) {
     std::vector<IValue> argTables;
     for (auto&& arg : args) {
@@ -315,6 +231,17 @@
   });
 
   // function tuple
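+  // For __getstate__/__setstate__, rebuild the qualified name from the
+  // uniqued name of the owning class type; all other functions keep their
+  // original qualified name.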
+  std::string qn;
+  if (func.name() == "__setstate__" || func.name() == "__getstate__") {
+    auto classtype = func.getSchema().arguments()[0].type()->cast<ClassType>();
+    TORCH_INTERNAL_ASSERT(
+        classtype, "class is null ", func.qualname().qualifiedName());
+    qn = c10::QualifiedName(
+             type_name_uniquer_.getUniqueName(classtype), func.name())
+             .qualifiedName();
+  } else {
+    qn = func.qualname().qualifiedName();
+  }
   auto bytecode_vals = to_tuple({qn, codeTable, schemaTable});
 
   c10::optional<IValue> debug_info_vals;
@@ -324,41 +251,27 @@
   // debug handles generated by debug_handle_manager
   // will correspond to {source_range, inlinedCallStackPtr} which we will
   // serialize separately.
-  IValue module_debug_tuple = c10::ivalue::Tuple::create(op_debug_handles);
+  IValue module_debug_tuple =
+      c10::ivalue::Tuple::create(mobile_code_ptr->debug_handles_);
   auto function_debug_info =
       Table({{"function_debug_handles", module_debug_tuple}});
   debug_info_vals = to_tuple({qn, function_debug_info});
   return std::make_pair(bytecode_vals, debug_info_vals);
 }
 
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
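+// Emits one (bytecode, debug-info) pair per mobile method and appends them
+// to the serialization element lists.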
+void pushMobileFunctionsToIValues(
+    const CompilationUnit& compilation_unit,
+    const mobile::Module& module,
     std::vector<c10::IValue>& elements,
     std::vector<c10::IValue>& debugInfoElements,
     BackendDebugInfoRecorder& recorder,
     TypeNameUniquer& uniquer) {
-  exportSet.visit(
-      [&](const c10::QualifiedName& qn, ExportedFunction& exported) {
-        auto tuple = getFunctionTuple(
-            exported.mod,
-            exported.function,
-            std::move(exported.optimizedGraph),
-            recorder,
-            qn.qualifiedName(),
-            uniquer);
-        elements.push_back(std::move(tuple.first));
-        debugInfoElements.push_back(std::move(tuple.second));
-      });
-}
-
-void pushFunctionToIValues(
-    BytecodeExportSet exportSet,
-    std::vector<c10::IValue>& elements,
-    BackendDebugInfoRecorder& recorder,
-    TypeNameUniquer& uniquer) {
-  std::vector<c10::IValue> debugInfoElements;
-  pushFunctionToIValues(
-      std::move(exportSet), elements, debugInfoElements, recorder, uniquer);
+  for (const auto& method : module.get_methods()) {
+    auto tuple = getFunctionTuple(
+        compilation_unit, method.function(), recorder, uniquer);
+    elements.push_back(std::move(tuple.first));
+    debugInfoElements.push_back(std::move(tuple.second));
+  }
 }
 
 std::unordered_set<const FunctionSchema*> getInterfaceCalls(Graph& graph) {
@@ -402,61 +315,6 @@
   return ret;
 }
 
-void exportFunction(
-    BytecodeExportSet& exportSet,
-    const ModuleMethod& method,
-    bool toplevel = false) {
-  const auto& func = method.function;
-  const auto& qn = method.exportName;
-  if (exportSet.contains(qn)) {
-    if (toplevel) {
-      exportSet.update(qn, toplevel);
-    }
-    return;
-  }
-  auto graph = func.graph()->copyUnique();
-  Inline(*graph);
-  auto interfaceCalls = getInterfaceCalls(*graph);
-  exportSet.add(
-      qn, ExportedFunction{method.module, func, std::move(graph), toplevel});
-
-  if (!getMobileInterfaceCallExport()) {
-    return;
-  }
-
-  auto interfaces = getModuleInterfaceExports(method.module, interfaceCalls);
-  for (const auto& interface : interfaces) {
-    exportFunction(exportSet, interface);
-  }
-}
-
-void setstateTuple(
-    BytecodeExportSet& exportSet,
-    const Module& module,
-    const IValue& ivalue,
-    TypeNameUniquer& type_name_uniquer_,
-    bool toplevel = false) {
-  if (!ivalue.isObject())
-    return;
-  auto obj = ivalue.toObject();
-  auto type = obj->type();
-  if (checkHasValidSetGetState(type)) {
-    Function& setstate = type->getMethod("__setstate__");
-    auto qn = type_name_uniquer_.getUniqueName(obj->type()).qualifiedName() +
-        "." + setstate.name();
-    if (exportSet.contains(qn)) {
-      return;
-    }
-    if (auto f = tryToGraphFunction(setstate)) {
-      exportFunction(exportSet, ModuleMethod{module, *f, qn}, toplevel);
-    }
-  } else {
-    for (size_t i = 0, n = type->numAttributes(); i < n; ++i) {
-      setstateTuple(exportSet, module, obj->getSlot(i), type_name_uniquer_);
-    }
-  }
-}
-
 bool isLoweredModule(const Module& m) {
   c10::QualifiedName type_name;
   if (m.type()->name()) {
@@ -544,24 +402,6 @@
   return mobileInterfaceCallExport().load(std::memory_order_relaxed);
 }
 
-BytecodeExportSet moduleMethodsTuple(
-    const Module& module,
-    TypeNameUniquer& type_name_uniquer_) {
-  BytecodeExportSet exportSet;
-  auto methods = module.get_methods();
-  // top level methods
-  for (const auto& method : methods) {
-    const auto& f = toGraphFunction(method.function());
-    exportFunction(
-        exportSet, ModuleMethod{module, f, f.qualname()}, /* toplevel */ true);
-  }
-
-  // __setstate__ of all components
-  setstateTuple(exportSet, module, module._ivalue(), type_name_uniquer_, true);
-
-  return exportSet;
-}
-
 void SetExportModuleExtraFilesHook(ExportModuleExtraFilesHook hook) {
   GetExtraFilesHook() = std::move(hook);
 }
@@ -774,9 +614,12 @@
   // Always save debug handles
   debug_info_elements.emplace_back(static_cast<int64_t>(version_to_write));
 
-  BytecodeExportSet exportSet = moduleMethodsTuple(module, type_name_uniquer_);
-  pushFunctionToIValues(
-      std::move(exportSet),
+  mobile::Module mobile_module =
+      jitModuleToMobile(module, getOptionsFromGlobal());
+
+  pushMobileFunctionsToIValues(
+      *module._ivalue()->compilation_unit(),
+      mobile_module,
       elements,
       debug_info_elements,
       debug_info_recorder,
@@ -840,9 +683,9 @@
     getBackendDebugInfoMap(module, backend_debug_info_map);
     // Now get the debug-handles-to-inlined-cs-ptr-map
     // And serialize that in a separate archive
-    auto debug_handle_cs_ptr_map = debug_info_recorder.stopRecording();
-    debug_handle_cs_ptr_map.insert(
-        backend_debug_info_map.begin(), backend_debug_info_map.end());
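+    // The mobile module's debug table already combines the recorded
+    // call-stack map with any backend debug info (see jitModuleToMobile),
+    // so read the merged map from there.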
+    const auto& debug_info = mobile_module.getDebugTable().getCallStackPtrMap();
+    BackendDebugInfoMapType debug_handle_cs_ptr_map(
+        debug_info.begin(), debug_info.end());
     CallStackDebugInfoPickler cs_debug_info_pickler;
     auto cs_data = cs_debug_info_pickler.pickle(
         debug_handle_cs_ptr_map, source_range_tags_);
@@ -962,31 +805,13 @@
 
 namespace {
 void export_opnames(const script::Module& m, std::set<std::string>& opnames) {
-  std::vector<c10::IValue> elements;
-  BackendDebugInfoRecorder dummy;
-  TypeNameUniquer dummy_uniquer = TypeNameUniquer();
-  BytecodeExportSet exportSet = moduleMethodsTuple(m, dummy_uniquer);
-  pushFunctionToIValues(std::move(exportSet), elements, dummy, dummy_uniquer);
-  for (const auto& element : elements) {
-    auto table = element.toTupleRef().elements()[1];
-    auto row =
-        table.toTupleRef().elements().at(BYTECODE_INDEX_OPERATOR).toTuple();
-    TORCH_INTERNAL_ASSERT(
-        row->elements().at(0).toStringRef() == "operators",
-        "Expected operators but found ",
-        row->elements().at(0).toStringRef());
-    const auto& ops_list = row->elements().at(1).toTupleRef().elements();
-    for (const auto& op : ops_list) {
-      const auto& op_item = op.toTupleRef().elements();
-      TORCH_CHECK(
-          op_item.size() >= 2,
-          "There should be either two parts (name and overload name), ",
-          "or three parts (name, overload name and number of specified args) ",
-          "for an operator.");
-      auto opname = op_item[0].toString()->string();
-      auto overload = op_item[1].toString()->string();
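+  // Convert to a mobile module and collect "name.overload_name" for every
+  // operator referenced by its methods.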
+  mobile::Module mobile_m = jitModuleToMobile(m, getOptionsFromGlobal());
+  for (const auto& method : mobile_m.get_methods()) {
+    for (const auto& op : method.function().get_code()->op_names_) {
       // NOLINTNEXTLINE(performance-inefficient-string-concatenation)
-      opnames.emplace(overload.empty() ? opname : opname + "." + overload);
+      opnames.emplace(
+          op.overload_name.empty() ? op.name
+                                   : op.name + "." + op.overload_name);
     }
   }
 }