[Pytorch][Ondevice quantization] Add device side API to convert model (#83807)

Summary:
This diff adds a device-side API that converts a model to its
quantized equivalent. The input model must have been prepared AOT
(ahead of time) for quantization.

The API is implemented by:
- Running the reset observers method
- Running the observe method
- Running the quantize method
- Replacing the original method (e.g. forward) with its quantized equivalent
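
A minimal sketch of the intended end-to-end flow, mirroring the new test
(`model`, `qconfig_dict`, and `example_input` are placeholders here, and
`OnDevicePTQUtils` is a helper from this PR's test utilities):

```python
import io

import torch
from torch.jit.mobile import _load_for_lite_interpreter
from torch.utils import bundled_inputs

# AOT (server side): prepare the scripted model for on-device PTQ and
# bundle a randomly generated calibration input of the right shape.
m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
rand_input = bundled_inputs.bundle_randn(example_input.size(), dtype=example_input.dtype)
m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input,)])
buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
buffer.seek(0)

# Device side: load for the lite interpreter and convert in place. After
# this call, forward uses quantized ops and the helper methods
# (reset_observers_/observe_/quantize_/quantized_forward) are removed.
m = _load_for_lite_interpreter(buffer)
torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
output = m(example_input)
```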

Test Plan:
test/quantization/jit/test_ondevice_quantization.py
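
The test also round-trips the converted module through the flatbuffer
format. A sketch of that check (continuing from the snippet above; `m` and
`example_input` carry over):

```python
from typing import Dict

import torch._C_flatbuffer
from torch.jit.mobile import LiteScriptModule

extra_files: Dict[str, str] = {}  # no extra files needed for this round trip
module_bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, extra_files)
m2 = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(module_bytes))
output = m2(example_input)
```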

Differential Revision: [D38889818](https://our.internmc.facebook.com/intern/diff/D38889818)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83807
Approved by: https://github.com/iseeyuan
diff --git a/buckbuild.bzl b/buckbuild.bzl
index ae1519e..76e0db9 100644
--- a/buckbuild.bzl
+++ b/buckbuild.bzl
@@ -1417,6 +1417,7 @@
             "torch/csrc/autograd/VariableTypeManual.cpp",
             "torch/csrc/autograd/FunctionsManual.cpp",
             "torch/csrc/api/src/data/datasets/mnist.cpp",
+            "torch/csrc/jit/mobile/quantization.cpp",
             "torch/csrc/jit/mobile/train/export_data.cpp",
             "torch/csrc/jit/mobile/train/optim/sgd.cpp",
             "torch/csrc/jit/mobile/train/random.cpp",
diff --git a/build_variables.bzl b/build_variables.bzl
index eb09a2a..ec08b91 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -564,6 +564,7 @@
     "torch/csrc/jit/mobile/observer.cpp",
     "torch/csrc/jit/mobile/parse_bytecode.cpp",
     "torch/csrc/jit/mobile/parse_operators.cpp",
+    "torch/csrc/jit/mobile/quantization.cpp",
     "torch/csrc/jit/mobile/upgrader_mobile.cpp",
     "torch/csrc/jit/runtime/register_prim_ops.cpp",
     "torch/csrc/jit/runtime/register_special_ops.cpp",
@@ -612,6 +613,7 @@
     "torch/csrc/jit/mobile/observer.cpp",
     "torch/csrc/jit/mobile/parse_bytecode.cpp",
     "torch/csrc/jit/mobile/parse_operators.cpp",
+    "torch/csrc/jit/mobile/quantization.cpp",
     "torch/csrc/jit/mobile/train/export_data.cpp",
     "torch/csrc/jit/mobile/train/optim/sgd.cpp",
     "torch/csrc/jit/mobile/train/random.cpp",
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index a904898..584d550 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -560,6 +560,7 @@
        ${TORCH_SRC_DIR}/csrc/jit/mobile/observer.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_bytecode.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/parse_operators.cpp
+       ${TORCH_SRC_DIR}/csrc/jit/mobile/quantization.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/export_data.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/optim/sgd.cpp
        ${TORCH_SRC_DIR}/csrc/jit/mobile/train/random.cpp
diff --git a/test/quantization/jit/test_ondevice_quantization.py b/test/quantization/jit/test_ondevice_quantization.py
index 8b453a5..fa3cfaa 100644
--- a/test/quantization/jit/test_ondevice_quantization.py
+++ b/test/quantization/jit/test_ondevice_quantization.py
@@ -2,6 +2,7 @@
 # Owner(s): ["oncall: quantization"]
 
 import torch
+import torch._C_flatbuffer
 
 from torch.ao.quantization import (
     default_dynamic_qconfig,
@@ -22,11 +23,13 @@
     LinearAddModel,
 )
 
-from torch.jit.mobile import _load_for_lite_interpreter
+from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule
 
 from torch.testing import FileCheck
+from torch.utils import bundled_inputs as bundled_inputs
 
 import io
+from typing import Dict
 
 class myMod(torch.nn.Module):
     def __init__(self, weight):
@@ -396,7 +399,7 @@
         self.assertTrue(thrown)
 
 
-    def _check_serialization_deserialization(self, model):
+    def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False):
         model.eval()
         inputs = model.get_example_inputs()
         ref_m = torch.jit.script(model)
@@ -410,27 +413,40 @@
         ref_m = torch.jit.load(buffer)
         ref_output = ref_m(*inputs)
 
-        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
-        buffer = io.BytesIO()
-        torch.jit.save(m, buffer)
-        buffer.seek(0)
-        m = torch.jit.load(buffer)
-        m.reset_observers_forward()
-        m.observe_forward(*inputs)
-        m.quantize_forward(*inputs)
-        output = m.quantized_forward(*inputs)
-        self.assertTrue(torch.allclose(ref_output, output))
+        if not check_device_side_api:
+            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+            buffer = io.BytesIO()
+            torch.jit.save(m, buffer)
+            buffer.seek(0)
+            m = torch.jit.load(buffer)
+            m.reset_observers_forward()
+            m.observe_forward(*inputs)
+            m.quantize_forward(*inputs)
+            output = m.quantized_forward(*inputs)
+            self.assertTrue(torch.allclose(ref_output, output))
+        else:
+            # check for lite interpreter
+            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+            first_input, = inputs
+            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
+            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
+            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+            buffer.seek(0)
+            m = _load_for_lite_interpreter(buffer)
+            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
+            self.assertFalse(m.find_method("quantized_forward"))
+            self.assertFalse(m.find_method("quantize_forward"))
+            self.assertFalse(m.find_method("observe_forward"))
+            self.assertFalse(m.find_method("reset_observers_forward"))
+            output = m(*inputs)
+            self.assertTrue(torch.allclose(ref_output, output))
 
-        # check for lite interpreter
-        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
-        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
-        buffer.seek(0)
-        m = _load_for_lite_interpreter(buffer)  # Error here
-        m.run_method("reset_observers_forward")
-        m.run_method("observe_forward", *inputs)
-        m.run_method("quantize_forward", *inputs)
-        output = m.run_method("quantized_forward", *inputs)
-        self.assertTrue(torch.allclose(ref_output, output))
+            # Now serialize to flatbuffer, load back from the flatbuffer bytes, and check
+            extra_files: Dict[str, str] = {}
+            module_bytes = torch._C_flatbuffer._save_mobile_module_to_bytes(m._c, extra_files)
+            m = LiteScriptModule(torch._C_flatbuffer._load_mobile_module_from_bytes(module_bytes))
+            fb_output = m(*inputs)
+            self.assertTrue(torch.allclose(ref_output, fb_output))
 
         model.eval()
         inputs = model.get_example_inputs()
@@ -445,27 +461,41 @@
         ref_m = torch.jit.load(buffer)
         ref_output = ref_m(*inputs)
 
-        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
-        buffer = io.BytesIO()
-        torch.jit.save(m, buffer)
-        buffer.seek(0)
-        m = torch.jit.load(buffer)
-        m.reset_observers_forward()
-        m.observe_forward(*inputs)
-        m.quantize_forward(*inputs)
-        output = m.quantized_forward(*inputs)
-        self.assertTrue(torch.allclose(ref_output, output))
+        if not check_device_side_api:
+            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+            buffer = io.BytesIO()
+            torch.jit.save(m, buffer)
+            buffer.seek(0)
+            m = torch.jit.load(buffer)
+            m.reset_observers_forward()
+            m.observe_forward(*inputs)
+            m.quantize_forward(*inputs)
+            output = m.quantized_forward(*inputs)
+            self.assertTrue(torch.allclose(ref_output, output))
+        else:
+            # check for lite interpreter
+            m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
+            first_input, = inputs
+            rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype)
+            m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )])
+            buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
+            buffer.seek(0)
+            m = _load_for_lite_interpreter(buffer)
+            torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward")
+            self.assertFalse(m.find_method("quantized_forward"))
+            self.assertFalse(m.find_method("quantize_forward"))
+            self.assertFalse(m.find_method("observe_forward"))
+            self.assertFalse(m.find_method("reset_observers_forward"))
+            output = m(*inputs)
+            self.assertTrue(torch.allclose(ref_output, output))
 
-        # check for lite interpreter
-        m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict)
-        buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter())
-        buffer.seek(0)
-        m = _load_for_lite_interpreter(buffer)  # Error here
-        m.run_method("reset_observers_forward")
-        m.run_method("observe_forward", *inputs)
-        m.run_method("quantize_forward", *inputs)
-        output = m.run_method("quantized_forward", *inputs)
-        self.assertTrue(torch.allclose(ref_output, output))
+
+    def _check_serialization_deserialization(self, model):
+        self._check_serdes_and_device_side_api_helper(model, False)
+
+
+    def _check_device_side_api(self, model):
+        self._check_serdes_and_device_side_api_helper(model, True)
 
 
     def test_quantize_forward(self):
@@ -492,3 +522,8 @@
     def test_serialization_deserialization(self):
         model = MyConvLinearModule()
         self._check_serialization_deserialization(model)
+
+
+    def test_device_side_api(self):
+        model = MyConvLinearModule()
+        self._check_device_side_api(model)
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 92a3631..024c5ae 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -298,6 +298,7 @@
 def _load_for_lite_interpreter(filename: Union[str, Path], map_location: Union[_device, str, None]): ...
 def _load_for_lite_interpreter_from_buffer(buffer: BinaryIO, map_location: Union[_device, str, None]): ...
 def _export_operator_list(module: LiteScriptModule): ...
+def _quantize_ondevice_ptq_dynamic(module: LiteScriptModule, method_name: str): ...
 def _get_model_bytecode_version(filename: Union[str, Path]) -> _int: ...
 def _get_model_bytecode_version_from_buffer(buffer: BinaryIO) -> _int: ...
 def _backport_for_mobile(filename_input: Union[str, Path], filename_output: Union[str, Path], to_version: _int) -> None: ...
diff --git a/torch/csrc/jit/mobile/module.cpp b/torch/csrc/jit/mobile/module.cpp
index 2ef7c34..5da8cb4 100644
--- a/torch/csrc/jit/mobile/module.cpp
+++ b/torch/csrc/jit/mobile/module.cpp
@@ -43,6 +43,50 @@
   AT_ERROR("Method '", name, "' is not defined.");
 }
 
+bool Module::compareMethodSchemas(
+    const std::string& name_1,
+    const std::string& name_2) {
+  c10::optional<c10::FunctionSchema> schema_1, schema_2;
+  for (const auto& fn : cu_->methods()) {
+    if (fn->name() == name_1) {
+      schema_1 = fn->getSchema();
+    }
+    if (fn->name() == name_2) {
+      schema_2 = fn->getSchema();
+    }
+  }
+  if (schema_1.has_value() && schema_2.has_value()) {
+    return (schema_1 == schema_2);
+  }
+  return false;
+}
+
+void Module::unsafeRemoveMethod(const std::string& basename) {
+  int64_t i = 0;
+  for (; i < static_cast<int64_t>(cu_->methods().size()); ++i) {
+    if ((cu_->methods()[i])->name() == basename) {
+      break;
+    }
+  }
+  object_->type()->unsafeRemoveMethod(basename);
+  cu_->unsafeRemoveFunction(i);
+}
+
+void Module::unsafeCopyMethod(
+    const std::string& new_method_name,
+    const Function& to_be_copied) {
+  TORCH_CHECK(
+      !find_method(new_method_name).has_value(),
+      "Trying to replace existing method.");
+  const c10::QualifiedName& tobe_copied_name = to_be_copied.qualname();
+  c10::QualifiedName qualified_method_name(
+      tobe_copied_name.prefix(), new_method_name);
+  std::unique_ptr<Function> new_fn = std::make_unique<Function>(
+      qualified_method_name, to_be_copied.get_code(), to_be_copied.getSchema());
+  object_->type()->addMethod(new_fn.get());
+  cu_->register_function(std::move(new_fn));
+}
+
 c10::optional<Method> Module::find_method(const std::string& basename) const {
   for (const auto& fn : cu_->methods()) {
     if (fn->name() == basename) {
diff --git a/torch/csrc/jit/mobile/module.h b/torch/csrc/jit/mobile/module.h
index 01c76e1..2b07831 100644
--- a/torch/csrc/jit/mobile/module.h
+++ b/torch/csrc/jit/mobile/module.h
@@ -3,6 +3,7 @@
 #include <torch/csrc/jit/mobile/debug_info.h>
 #include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/mobile/method.h>
+#include <torch/csrc/jit/mobile/quantization.h>
 
 namespace torch {
 namespace jit {
@@ -42,6 +43,10 @@
   Function* find_function(const c10::QualifiedName& qn);
   const Function* find_function(const c10::QualifiedName& qn) const;
 
+  void unsafeRemoveFunction(const int64_t index) {
+    methods_.erase(methods_.begin() + index);
+  }
+
  private:
   std::vector<std::unique_ptr<Function>> methods_;
 };
@@ -71,6 +76,7 @@
     return get_method("forward")(std::move(inputs));
   }
   c10::optional<Method> find_method(const std::string& basename) const;
+
   const std::string name() const {
     return object_->name();
   }
@@ -152,6 +158,18 @@
   }
 
  private:
+  friend class quantization::PTQQuanizationHelper;
+
+  bool compareMethodSchemas(
+      const std::string& name_1,
+      const std::string& name_2);
+
+  void unsafeRemoveMethod(const std::string& basename);
+
+  void unsafeCopyMethod(
+      const std::string& new_method_name,
+      const Function& to_be_copied);
+
   c10::intrusive_ptr<c10::ivalue::Object> object_;
   std::unordered_map<std::string, std::string> metadata_;
   std::shared_ptr<CompilationUnit> cu_;
diff --git a/torch/csrc/jit/mobile/quantization.cpp b/torch/csrc/jit/mobile/quantization.cpp
new file mode 100644
index 0000000..b391cf5
--- /dev/null
+++ b/torch/csrc/jit/mobile/quantization.cpp
@@ -0,0 +1,66 @@
+#include <ATen/Context.h>
+#include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/quantization.h>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+namespace quantization {
+
+void PTQQuanizationHelper::quantize_dynamic(
+    torch::jit::mobile::Module& m,
+    const std::string& method_name) {
+  at::globalContext().setReleaseWeightsWhenPrepacking(false);
+  std::string reset_observers_method_name = "reset_observers_" + method_name;
+  std::string observe_method_name = "observe_" + method_name;
+  std::string quantize_method_name = "quantize_" + method_name;
+  std::string quantized_method_name = "quantized_" + method_name;
+
+  TORCH_CHECK(
+      m.find_method(reset_observers_method_name).has_value(),
+      "PTQ ready module must have ",
+      reset_observers_method_name,
+      " method.");
+  TORCH_CHECK(
+      m.find_method(observe_method_name),
+      "PTQ ready module must have ",
+      observe_method_name,
+      " method.");
+  TORCH_CHECK(
+      m.find_method(quantize_method_name),
+      "PTQ ready module must have ",
+      quantize_method_name,
+      " method.");
+  TORCH_CHECK(
+      m.find_method(quantized_method_name),
+      "PTQ ready module must have ",
+      quantized_method_name,
+      " method.");
+  TORCH_CHECK(
+      m.find_method("get_all_bundled_inputs"),
+      "PTQ ready module must have get_all_bundled_inputs method.");
+
+  auto inputs = m.run_method("get_all_bundled_inputs")
+                    .toList()
+                    .get(0)
+                    .toTupleRef()
+                    .elements()
+                    .vec();
+  m.get_method(reset_observers_method_name)({});
+  m.get_method(observe_method_name)(inputs);
+  m.get_method(quantize_method_name)(inputs);
+
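+  // Replace <method_name> with its quantized equivalent, then drop all of
+  // the on-device PTQ helper methods so that only the quantized entry
+  // point remains on the module.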
+  TORCH_CHECK(
+      m.compareMethodSchemas(method_name, quantized_method_name),
+      "Schemas of ",
+      method_name,
+      " and ",
+      quantized_method_name,
+      " must match.");
+  m.unsafeRemoveMethod(method_name);
+  const Function& to_be_copied =
+      m.find_method(quantized_method_name).value().function();
+  m.unsafeCopyMethod(method_name, to_be_copied);
+  m.unsafeRemoveMethod(quantized_method_name);
+  m.unsafeRemoveMethod(quantize_method_name);
+  m.unsafeRemoveMethod(observe_method_name);
+  m.unsafeRemoveMethod(reset_observers_method_name);
+}
+} // namespace quantization
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/mobile/quantization.h b/torch/csrc/jit/mobile/quantization.h
new file mode 100644
index 0000000..aa47dcb
--- /dev/null
+++ b/torch/csrc/jit/mobile/quantization.h
@@ -0,0 +1,38 @@
+#pragma once
+
+#include <c10/macros/Export.h>
+#include <string>
+
+namespace torch {
+namespace jit {
+namespace mobile {
+class Module;
+namespace quantization {
+/*
+ * Device-side PTQ API.
+ * Once the model has been prepared for quantization on the server side, it
+ * is sent to the device. On the device the model is further trained. At the
+ * end of training, before the model is readied for inference, it must be
+ * quantized.
+ * Usage of this API is as follows:
+ *   PTQQuanizationHelper ptq_helper;
+ *   ptq_helper.quantize_dynamic(m, "forward");
+ * Args:
+ *   m: An instance of mobile::Module, captured by reference. The module is
+ *      mutated in place, replacing its <method_name> method with the
+ *      quantized equivalent.
+ *   method_name: Name of the method to be quantized. AOT preparation for
+ *      quantization must also have been done for this method.
+ * Returns:
+ *   The in-place mutated `m`, whose size should be smaller due to weight
+ *   quantization and whose <method_name> method should use quantized ops.
+ */
+class TORCH_API PTQQuanizationHelper {
+ public:
+  PTQQuanizationHelper() = default;
+  void quantize_dynamic(
+      torch::jit::mobile::Module& m,
+      const std::string& method_name);
+};
+} // namespace quantization
+} // namespace mobile
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/python/script_init.cpp b/torch/csrc/jit/python/script_init.cpp
index 3f825d9..110c2f4 100644
--- a/torch/csrc/jit/python/script_init.cpp
+++ b/torch/csrc/jit/python/script_init.cpp
@@ -16,6 +16,7 @@
 #include <torch/csrc/jit/mobile/file_format.h>
 #include <torch/csrc/jit/mobile/import.h>
 #include <torch/csrc/jit/mobile/module.h>
+#include <torch/csrc/jit/mobile/quantization.h>
 #include <torch/csrc/jit/operator_upgraders/upgraders.h>
 #include <torch/csrc/jit/operator_upgraders/upgraders_entry.h>
 #include <torch/csrc/jit/operator_upgraders/utils.h>
@@ -1953,6 +1954,12 @@
   m.def("_export_operator_list", [](torch::jit::mobile::Module& sm) {
     return debugMakeSet(torch::jit::mobile::_export_operator_list(sm));
   });
+  m.def(
+      "_quantize_ondevice_ptq_dynamic",
+      [](mobile::Module& m, const std::string& method_name) {
+        mobile::quantization::PTQQuanizationHelper ptq_helper;
+        ptq_helper.quantize_dynamic(m, method_name);
+      });
 
   m.def("_jit_set_emit_hooks", setEmitHooks);
   m.def("_jit_get_emit_hooks", getEmitHooks);