[quant][graphmode][refactor] Remove unused code in quantization.cpp (#37974)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/37974
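
This removes dead code paths left over from earlier iterations of graph-mode quantization: the unimplemented FoldQuantNodesIntoInputsOutputs pass, the FoldQuantizeCallIntoBuffer and FoldPrepackedWeightIntoModule passes together with their Python bindings (_jit_pass_fold_quantize, _jit_pass_fold_prepack, _jit_pass_fold_quant_inputs), the ConvPackedParams wrapper module and its eager scripting in _quantize_script.py, the unused toTwoElementIntList / PatternsAndModules helpers in quantization.cpp, and the tests that exercised them (test_fold_quantize, test_fold_prepack, test_fold_quantize_freeze). The internal helper get_scripted_qconfig_dict is renamed to script_qconfig_dict.

For reference, a minimal usage sketch of the renamed helper (not part of this diff; it assumes a top-level qconfig_dict whose empty-string key denotes the top-level module, as the removed test_fold_prepack used):

    import torch
    from torch.quantization import default_qconfig
    from torch.quantization._quantize_script import script_qconfig_dict

    # Scripts the observer constructors in each QConfig so they can be handed
    # to the C++ insert-observers pass; None entries are passed through as None.
    scripted_qconfig_dict = script_qconfig_dict({'': default_qconfig})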

Differential Revision: D21468498

Pulled By: jerryzh168

fbshipit-source-id: 96f34db9f98474ec8e5d33e9b7c406b1637f5de8
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 783bcfa..2db5382 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -12,30 +12,6 @@
                        "instead.")
 
 class TestFreezing(JitTestCase):
-    def test_fold_quantize_freeze(self):
-        class M(nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.weight = nn.Parameter(torch.tensor([2], dtype=torch.float))
-
-            def forward(self, x):
-                return torch.quantize_per_tensor(self.weight, 2.0, 0, torch.quint8)
-
-        m = torch.jit.script(M())
-        m.eval()
-        torch._C._jit_pass_fold_quantize(m._c, 'forward')
-        m._c = torch._C._freeze_module(m._c)
-        self.assertFalse(m._c.hasattr('_quantized_weight'))
-        FileCheck().check_not('GetAttr[name=') \
-                   .run(m._c._get_method('forward').graph)
-        buffer = io.BytesIO()
-        torch.jit.save(m, buffer)
-        buffer.seek(0)
-        m_l = torch.jit.load(buffer)
-        self.assertFalse(m_l._c.hasattr('_quantized_weight'))
-        FileCheck().check_not('GetAttr[name=') \
-                   .run(m_l._c._get_method('forward').graph)
-
     def test_freeze_module(self):
         class M(nn.Module):
             def __init__(self):
diff --git a/test/quantization/test_quantize_script.py b/test/quantization/test_quantize_script.py
index d0507b9..b8e612c 100644
--- a/test/quantization/test_quantize_script.py
+++ b/test/quantization/test_quantize_script.py
@@ -11,15 +11,11 @@
 from torch.quantization import default_dynamic_qconfig
 from torch.quantization import QConfigDynamic
 from torch.quantization import default_observer
-from torch.quantization import default_weight_observer
 from torch.quantization import default_per_channel_weight_observer
 from torch.quantization import default_qconfig
 from torch.quantization import get_default_qconfig
-from torch.quantization import quantize
 
 # torch.quantization._quantize_script
-from torch.nn.quantized.modules.linear import LinearPackedParams
-from torch.quantization._quantize_script import ConvPackedParams
 from torch.quantization._quantize_script import script_qconfig
 from torch.quantization._quantize_script import prepare_script
 from torch.quantization._quantize_script import convert_script
@@ -28,8 +24,6 @@
 from torch.quantization._quantize_script import quantize_dynamic_script
 
 # Testing utils
-from torch.testing._internal.common_quantization import SingleLayerLinearModel, AnnotatedSingleLayerLinearModel
-from torch.testing._internal.common_quantization import ConvModel, AnnotatedConvModel
 from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
 
 from torch.testing import FileCheck
@@ -793,105 +787,6 @@
                            .check("return") \
                            .run(conv._get_method('_conv_forward').graph)
 
-    def test_fold_quantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.weight = torch.nn.Parameter(torch.tensor([2], dtype=torch.float))
-
-            def forward(self, x):
-                return torch.quantize_per_tensor(self.weight, 2.0, 0, torch.quint8)
-
-        m = torch.jit.script(M())
-        torch._C._jit_pass_fold_quantize(m._c, 'forward')
-        self.assertTrue(m._c.hasattr('_quantized_weight'))
-        FileCheck().check_not('GetAttr[name="weight"]') \
-                   .check('GetAttr[name="_quantized_weight"]') \
-                   .run(m._c._get_method('forward').graph)
-
-    @unittest.skipUnless(
-        'fbgemm' in torch.backends.quantized.supported_engines,
-        " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
-        " with instruction set support avx2 or newer.",
-    )
-    @unittest.skip("Skip for now since we changed scale/zero_point to attributes."
-                   "We'll enable this in a separate PR")
-    def test_fold_prepack(self):
-        def copy_weights(name, m, ref_m):
-            if name == 'linear':
-                m.fc1.weight = torch.nn.Parameter(ref_m.fc1.module.weight.detach())
-                m.fc1.bias = torch.nn.Parameter(ref_m.fc1.module.bias.detach())
-            else:
-                m.conv.weight = torch.nn.Parameter(ref_m.conv.weight.detach())
-
-        for is_per_channel in [True, False]:
-            for name, M, ref_M, data in [
-                    ('linear',
-                     SingleLayerLinearModel,
-                     AnnotatedSingleLayerLinearModel,
-                     torch.randn((5, 5), dtype=torch.float)),
-                    ('conv',
-                     ConvModel,
-                     AnnotatedConvModel,
-                     torch.randn((1, 3, 7, 7), dtype=torch.float))]:
-                qconfig = QConfig(
-                    activation=default_observer,
-                    weight=default_per_channel_weight_observer if is_per_channel else default_weight_observer)
-                # eager mode
-                ref_m = ref_M()
-                m = M()
-                copy_weights(name, m, ref_m)
-                ref_m.qconfig = qconfig
-                ref_m = quantize(ref_m, _test_only_eval_fn, [(data, torch.randint(0, 1, (5,), dtype=torch.long))])
-                ref_res = ref_m(data)
-                # script mode
-                m = torch.jit.script(m)
-                qconfig_dict = {
-                    '': script_qconfig(qconfig)
-                }
-                m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, 'forward', qconfig_dict, False))
-                get_forward(m._c)(data)
-                m = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(m._c, 'forward', False))
-                torch._C._jit_pass_insert_prepack_unpack(m._c)
-                linear_packed_params = torch.jit.script(LinearPackedParams())._c
-                conv_packed_params = torch.jit.script(ConvPackedParams())._c
-                torch._C._jit_pass_fold_prepack(m._c,
-                                                linear_packed_params,
-                                                conv_packed_params)
-                res = get_forward(m._c)(data)
-                # check result
-                self.assertEqual(res, ref_res)
-
-                # check attributes
-                # construct a RecursiveScriptModule
-                m = wrap_cpp_module(m._c)
-                mod_to_inspect = m.fc1 if name == 'linear' else m.conv
-                packed_module_list = [x for x, _ in mod_to_inspect._modules._c.items()
-                                      if x.startswith('_' + name + '_packed_params_module')]
-                assert len(packed_module_list) == 1, \
-                    'Expected to have one packed_params_module'
-                packed_module_name = packed_module_list[0]
-                # check values
-                w, _ = mod_to_inspect._c.getattr(packed_module_name)._get_method('_weight_bias')()
-                original_w = mod_to_inspect.weight
-                if is_per_channel:
-                    ref_w = torch.quantize_per_channel(original_w,
-                                                       w.q_per_channel_scales(),
-                                                       w.q_per_channel_zero_points(),
-                                                       w.q_per_channel_axis(),
-                                                       w.dtype).dequantize()
-                else:
-                    ref_w = torch.quantize_per_tensor(original_w, w.q_scale(), w.q_zero_point(), w.dtype).dequantize()
-                self.assertEqual(ref_w, w.dequantize())
-
-                # test serialization
-                buffer = io.BytesIO()
-                torch.jit.save(m, buffer)
-                buffer.seek(0)
-                loaded_mod = torch.jit.load(buffer)
-                loaded_res = loaded_mod(data)
-                self.assertEqual(ref_res, loaded_res)
-
     def test_dedup_module_uses(self):
         class M(torch.nn.Module):
             def __init__(self):
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 6f125e9..7bc1424 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -30,7 +30,6 @@
 using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
 using NameModuleVector = std::vector<std::pair<std::string, Module>>;
 using graph_rewrite_helper::getFuncName;
-using graph_rewrite_helper::getIValue;
 using graph_rewrite_helper::getValue;
 using graph_rewrite_helper::PatternInfo;
 using graph_rewrite_helper::replaceConvolutionWithAtenConv;
@@ -40,13 +39,6 @@
 // _scalar_type and _axis(for per channel quantization)
 using QParamVector = std::vector<std::pair<std::string, IValue>>;
 
-struct PatternsAndModules {
-  bool is_conv;
-  bool is_per_channel;
-  const PatternInfo& pattern;
-  Module packed_params_module;
-};
-
 std::vector<std::string> _static_quantizable_call_funcs = {
     "conv2d",
     "linear",
@@ -1187,7 +1179,6 @@
   return nodeQuantizable(use.user, is_dynamic);
 }
 
-// TODO: remove this as a class method
 bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
   if (isBiasOfConvOrLinear(v) ||
       !(v->type()->isSubtypeOf(TensorType::get()) ||
@@ -2230,26 +2221,6 @@
   }
 }
 
-c10::optional<IValue> toTwoElementIntList(Value* v) {
-  auto* n = v->node();
-  if (n->kind() == prim::Constant) {
-    auto iv = toIValue(v);
-    if (iv && iv.value().isIntList() && iv.value().toIntList().size() == 2) {
-      return iv;
-    }
-  }
-
-  if (n->kind() == prim::ListConstruct && n->inputs().size() == 2) {
-    auto e0 = toIValue(n->inputs()[0]);
-    auto e1 = toIValue(n->inputs()[1]);
-    if (!e0 || !e1 || !e0.value().isInt() || !e1.value().isInt()) {
-      return c10::nullopt;
-    }
-    return IValue(c10::List<int64_t>({e0.value().toInt(), e1.value().toInt()}));
-  }
-  return c10::nullopt;
-}
-
 // A helper class to make uses of module unique
 class ModuleUseDeduper {
  public:
@@ -2851,10 +2822,6 @@
   return module;
 }
 
-void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
-  throw std::runtime_error("Pass not implemented yet!");
-}
-
 void SwapFunctionalLinear(Module& module) {
   for (auto& method : module.get_methods()) {
     std::shared_ptr<Graph> g = method.graph();
@@ -3004,62 +2971,6 @@
   return m;
 }
 
-void FoldQuantizeCallIntoBuffer(
-    Module& module,
-    const std::string& method_name) {
-  const PatternInfo& pattern = PatternInfo::parse_from_str(R"(
-graph(%self, %scale, %zero_point, %dtype):
-   %weight = prim::GetAttr[name="weight"](%self)
-   %weight_quant = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
-   return (%weight_quant) )");
-  const Graph& pattern_graph = *pattern.pattern_graph;
-  const auto& vmap = pattern.vmap;
-
-  auto method = module.get_method(method_name);
-  auto graph = method.graph();
-  const auto& matches = findPatternMatches(pattern_graph, *graph);
-  // Extra filter on scale/zero_point/dtype to make sure they are Constant
-  auto filter = [](const Match& match,
-                   const std::unordered_map<std::string, Value*>& vmap) {
-    const auto& match_vmap = match.values_map;
-    auto scale_node = match_vmap.at(vmap.at("scale"))->node();
-    auto zero_point_node = match_vmap.at(vmap.at("zero_point"))->node();
-    auto dtype_node = match_vmap.at(vmap.at("dtype"))->node();
-    return scale_node->kind() == prim::Constant &&
-        zero_point_node->kind() == prim::Constant &&
-        dtype_node->kind() == prim::Constant;
-  };
-  std::unordered_set<Node*> nodes_to_delete;
-  for (const auto& match : matches) {
-    if (!filter(match, vmap)) {
-      continue;
-    }
-    auto match_vmap = match.values_map;
-    auto float_weight = module.attr("weight").toTensor().data();
-    auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
-    auto zero_point =
-        toIValue(match_vmap.at(vmap.at("zero_point"))).value().toInt();
-    auto dtype =
-        toIValue(match_vmap.at(vmap.at("dtype"))).value().toScalarType();
-    module.register_buffer(
-        "_quantized_weight",
-        at::quantize_per_tensor(float_weight, scale, zero_point, dtype));
-
-    // Replace the GetAttr[weight]->quantize_per_tensor sequence
-    // with a simple GetAttr[_quantized_weight] node.
-    Value* orig_weight = match_vmap.at(vmap.at("weight"));
-    Value* orig_weight_quant = match_vmap.at(vmap.at("weight_quant"));
-
-    orig_weight->node()->s_(attr::name, "_quantized_weight");
-    orig_weight_quant->replaceAllUsesWith(orig_weight);
-    nodes_to_delete.insert(orig_weight_quant->node());
-  }
-
-  for (Node* n : nodes_to_delete) {
-    n->destroy();
-  }
-}
-
 void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
   insertPrepackUnpackForLinear(graph);
   insertPrepackUnpackForConv(graph);
@@ -3075,192 +2986,6 @@
   }
 }
 
-struct FoldPrepackedWeightIntoModuleHelper {
-  void run(
-      Module& module,
-      const std::string& method_name,
-      const Module& linear_params_module,
-      const Module& conv_params_module) {
-    auto method = module.get_method(method_name);
-    auto graph = method.graph();
-    GRAPH_DUMP("Before FoldPrepackWeightIntoModule: ", graph);
-
-    // (is_conv, is_per_channel, pattern, packed_params_module)
-    std::vector<PatternsAndModules> pattern_and_modules = {
-        {false, false, linear_prepack_per_tensor, linear_params_module},
-        {false, true, linear_prepack_per_channel, linear_params_module},
-        {true, false, conv2d_prepack, conv_params_module},
-        {true, true, conv2d_prepack_per_channel, conv_params_module}};
-    for (const auto& pm : pattern_and_modules) {
-      const Graph& pattern_graph = *pm.pattern.pattern_graph;
-      const auto& vmap = pm.pattern.vmap;
-      const auto& matches = findPatternMatches(pattern_graph, *graph);
-      TORCH_INTERNAL_ASSERT(
-          matches.size() <= 1, "We only support at most one match right now");
-      for (const auto& match : matches) {
-        const auto& match_vmap = match.values_map;
-        auto w_dtype_opt = getIValue("w_dtype", match_vmap, vmap);
-        auto w_scale_opt = getIValue("w_scale", match_vmap, vmap);
-        auto w_zero_point_opt = getIValue("w_zero_point", match_vmap, vmap);
-        if (!w_dtype_opt || !w_scale_opt || !w_zero_point_opt) {
-          GRAPH_DEBUG(
-              "dtype, scale or zero_point for weight(",
-              getValue("w_dtype", match_vmap, vmap)->debugName(),
-              ", ",
-              getValue("w_scale", match_vmap, vmap)->debugName(),
-              ", ",
-              getValue("w_zero_point", match_vmap, vmap)->debugName(),
-              ") is not constant, skipping the match.");
-          continue;
-        }
-        auto w_dtype = w_dtype_opt.value().toScalarType();
-        auto w = module.attr("weight").toTensor().data();
-        at::Tensor w_quant;
-        if (pm.is_per_channel) {
-          auto w_axis_opt = getIValue("w_axis", match_vmap, vmap);
-          if (!w_axis_opt) {
-            GRAPH_DEBUG(
-                "axis for weight ",
-                getValue("w_axis", match_vmap, vmap)->debugName(),
-                " is non-constant, skipping the match");
-            continue;
-          }
-          auto w_scale = w_scale_opt.value().toTensor().to(at::kFloat);
-          auto w_zero_point = w_zero_point_opt.value().toTensor().to(at::kInt);
-          int w_axis = w_axis_opt.value().toInt();
-          TORCH_CHECK(
-              w_scale.sizes() == w_zero_point.sizes(),
-              "scale and zero_point must have the same size");
-          w_quant = at::quantize_per_channel(
-              w, w_scale, w_zero_point, w_axis, w_dtype);
-        } else {
-          auto w_scale = w_scale_opt.value().toDouble();
-          auto w_zero_point = w_zero_point_opt.value().toInt();
-          w_quant = at::quantize_per_tensor(w, w_scale, w_zero_point, w_dtype);
-        }
-        c10::optional<at::Tensor> b = c10::nullopt;
-        if (hastensor(module, "bias")) {
-          b = module.attr("bias").toTensor().data();
-        }
-        Module wrapper_module = pm.packed_params_module.clone();
-        auto set_weight_bias = wrapper_module.get_method("set_weight_bias");
-        std::string module_name_prefix;
-        if (pm.is_conv) {
-          module_name_prefix = "_conv_packed_params_module_for_";
-          auto stride_opt =
-              toTwoElementIntList(getValue("stride", match_vmap, vmap));
-          auto padding_opt =
-              toTwoElementIntList(getValue("padding", match_vmap, vmap));
-          auto dilation_opt =
-              toTwoElementIntList(getValue("dilation", match_vmap, vmap));
-          auto groups_opt = getIValue("groups", match_vmap, vmap);
-          auto set_conv_params = wrapper_module.get_method("set_conv_params");
-          if (!stride_opt || !padding_opt || !dilation_opt) {
-            GRAPH_DEBUG(
-                "Failed to extract two element IntList for stride/padding/dilation, (",
-                getValue("stride", match_vmap, vmap)->debugName(),
-                ", ",
-                getValue("padding", match_vmap, vmap)->debugName(),
-                ", ",
-                getValue("dilation", match_vmap, vmap)->debugName(),
-                ") skipping the match");
-            continue;
-          }
-          set_conv_params(std::vector<IValue>{stride_opt.value(),
-                                              padding_opt.value(),
-                                              dilation_opt.value(),
-                                              groups_opt.value()});
-        } else {
-          module_name_prefix = "_linear_packed_params_module_for_";
-        }
-        set_weight_bias(std::vector<IValue>{IValue(w_quant), IValue(b)});
-        auto w_quant_val = getValue("w_quant", match_vmap, vmap);
-        // unique name for the module based on %w_quant
-        int uid = 0;
-        auto module_name = module_name_prefix + c10::to_string(uid++);
-        while (module.hasattr(module_name)) {
-          module_name_prefix + c10::to_string(uid++);
-        }
-        GRAPH_UPDATE("Adding new module: ", module_name);
-        module.register_module(module_name, wrapper_module);
-
-        // Add GetAttr of the packed module
-        auto packed_params_val = getValue("packed_params", match_vmap, vmap);
-        WithInsertPoint ins(packed_params_val->node());
-        // wrapper_module =
-        // self.{_conv,_linear}_packed_params_module_for_{unique_id}
-        Value* packed_params_module =
-            graph->insertGetAttr(graph->inputs()[0], module_name)
-                ->setType(wrapper_module.type());
-        GRAPH_UPDATE("Adding GetAttr node for the wrapper module");
-
-        // packed_params = wrapper_module._packed_params
-        Value* packed_params_from_attr =
-            graph->insertGetAttr(packed_params_module, "_packed_params");
-        GRAPH_UPDATE(
-            "Adding GetAttr node for _packed_params: ",
-            packed_params_from_attr->debugName());
-        packed_params_val->replaceAllUsesWith(packed_params_from_attr);
-
-        // Delete nodes
-        std::vector<Node*> nodes_to_delete = {w_quant_val->node(),
-                                              packed_params_val->node()};
-        for (auto n : nodes_to_delete) {
-          n->removeAllInputs();
-        }
-        for (auto n : nodes_to_delete) {
-          GRAPH_UPDATE("Deleting node: ", n);
-          n->destroy();
-        }
-      }
-    }
-  }
-
-  void run(
-      Module& module,
-      const Module& linear_params_module,
-      const Module& conv_params_module) {
-    for (auto& method : module.get_methods()) {
-      run(module, method.name(), linear_params_module, conv_params_module);
-    }
-    for (Module m : module.children()) {
-      run(m, linear_params_module, conv_params_module);
-    }
-  }
-
-  const PatternInfo linear_prepack_per_tensor = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype):
-        %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-        %packed_params = quantized::linear_prepack(%w_quant, %b)
-        return (%packed_params) )");
-
-  const PatternInfo linear_prepack_per_channel = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype):
-        %w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
-        %packed_params = quantized::linear_prepack(%w_quant, %b)
-        return (%packed_params) )");
-
-  const PatternInfo conv2d_prepack = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype, %stride, %padding, %dilation, %groups):
-        %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
-        %packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
-        return (%packed_params) )");
-
-  const PatternInfo conv2d_prepack_per_channel = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype, %stride, %padding, %dilation, %groups):
-        %w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
-        %packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
-        return (%packed_params) )");
-};
-
-void FoldPrepackedWeightIntoModule(
-    Module& module,
-    const Module& linear_params_module,
-    const Module& conv_params_module) {
-  FoldPrepackedWeightIntoModuleHelper h;
-  h.run(module, linear_params_module, conv_params_module);
-}
-
 void DedupModuleUses(Module& module) {
   ModuleUseDeduper d(module);
   d.dedup();
diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h
index 9a2381a..6b2732d 100644
--- a/torch/csrc/jit/passes/quantization.h
+++ b/torch/csrc/jit/passes/quantization.h
@@ -40,13 +40,6 @@
 using QConfigTypePtrMap =
     std::unordered_map<c10::optional<QConfig>, TypePtr, OptionalQConfigHash>;
 
-/** \brief Quantize model's inputs and outputs.
- *
- * This pass folds quant/dequant ops into the input/output tensors, essentially
- * quantizing these tensors. It's done to reduce model's memory footprint.
- */
-TORCH_API void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph);
-
 /** \brief Insert observer module and observer function call for
  *  the Tensors that needs to be observed.
  *
@@ -144,19 +137,6 @@
  */
 TORCH_API Module FoldConvBatchNorm2d(const Module& module);
 
-/** \brief Fold quantize function call into module
- *
- *  For the graph of the specified method of module, if we find a
- * quantize_per_tensor call on an attribute("weight") of the module, we'll
- * quantize the attribute directly and register a new buffer "_quantized_weight"
- * on the module and remove the quantize_per_tensor call and replace the use of
- * the quantized weight with
- *  "_quantized_weight".
- */
-TORCH_API void FoldQuantizeCallIntoBuffer(
-    Module& module,
-    const std::string& method_name);
-
 /** \brief Insert prepack and unpack function in graph
  *  We want add pack/unpack functions for quantized weight because later we want
  * to fold the packed weight as an attribute of the module, in order to reduce
@@ -176,25 +156,6 @@
  */
 TORCH_API void InsertPrepackUnpack(Module& module);
 
-/** \brief Fold prepack function call into module
- *
- *  For the graph of the specified method, if we find a
- * `quantized::linear_prepack` call, we'll clone the wrapper module and set the
- * weight and bias of the module and add the wrapper module as a child to the
- * input module. Folding is recursively applied to all methods of all child
- * modules of the input module
- *
- *  Wrapper module is used to overwrite serialization for packed
- *  weight and bias since they are not recognized by JIT, this
- *  is a workaround, a long term solution would be to support serialization of
- *  packed weight and bias using custom types
- *
- */
-TORCH_API void FoldPrepackedWeightIntoModule(
-    Module& module,
-    const Module& linear_params_module,
-    const Module& conv_params_module);
-
 /** Recursively deduplicate multiple uses of the same module by
  *  creating an instance clone for each use of the module, which means
  *  the type will be the same as before and all the attributes will be
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 6bfe4ab..58b0fab 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -217,12 +217,6 @@
           [](Module& module) { return freeze_module(module); },
           py::arg("module"))
       .def("_jit_pass_fuse_linear", &FuseLinear)
-      .def(
-          "_jit_pass_fold_quantize",
-          [](Module& module, const std::string& method_name) {
-            FoldQuantizeCallIntoBuffer(module, method_name);
-          })
-      .def("_jit_pass_fold_prepack", &FoldPrepackedWeightIntoModule)
       .def("_jit_pass_dedup_module_uses", &DedupModuleUses)
       .def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
       .def("_jit_pass_swap_dequantize", &PropagateQuantizationOps)
@@ -261,11 +255,6 @@
             subgraph_rewriter.runOnGraph(g);
           })
       .def(
-          "_jit_pass_fold_quant_inputs",
-          [](std::shared_ptr<Graph>& g) {
-            return FoldQuantNodesIntoInputsOutputs(g);
-          })
-      .def(
           "_jit_pass_remove_inplace_ops",
           [](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
       .def("_jit_pass_constant_pooling", ConstantPooling)
diff --git a/torch/quantization/_quantize_script.py b/torch/quantization/_quantize_script.py
index 4761864..467081ad 100644
--- a/torch/quantization/_quantize_script.py
+++ b/torch/quantization/_quantize_script.py
@@ -1,69 +1,9 @@
 from __future__ import absolute_import, division, print_function, unicode_literals
 
-from typing import List, Optional
-
 import torch
 from .qconfig import QConfig
 from torch.jit._recursive import wrap_cpp_module
 
-class ConvPackedParams(torch.nn.Module):
-    def __init__(self):
-        super(ConvPackedParams, self).__init__()
-        wq = torch._empty_affine_quantized([1, 1, 1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
-        self.stride = [1, 1]
-        self.padding = [0, 0]
-        self.dilation = [1, 1]
-        self.groups = 1
-        self.set_weight_bias(wq, None)
-
-    @torch.jit.export
-    def set_conv_params(self, stride, padding, dilation, groups):
-        # type: (List[int], List[int], List[int], int) -> None
-        self.stride = stride
-        self.padding = padding
-        self.dilation = dilation
-        self.groups = groups
-
-    @torch.jit.export
-    def set_weight_bias(self, weight, bias):
-        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
-        self._packed_params = torch.ops.quantized.conv2d_prepack(weight, bias, self.stride,
-                                                                 self.padding, self.dilation, self.groups)
-
-    @torch.jit.export
-    def _weight_bias(self):
-        return torch.ops.quantized.conv2d_unpack(self._packed_params)
-
-    def forward(self, x):
-        return x
-
-    @torch.jit.export
-    def __getstate__(self):
-        qweight, bias = self._weight_bias()
-        return (qweight,
-                bias,
-                self.stride,
-                self.padding,
-                self.dilation,
-                self.groups,
-                self.training)
-
-    @torch.jit.export
-    def __setstate__(self, state):
-        self.stride = state[2]
-        self.padding = state[3]
-        self.dilation = state[4]
-        self.groups = state[5]
-        self.set_weight_bias(state[0],
-                             state[1])
-        self.training = state[6]
-
-linear_packed_params = None
-conv_packed_params = None
-if 'fbgemm' in torch.backends.quantized.supported_engines:
-    linear_packed_params = torch.jit.script(torch.nn.quantized.modules.linear.LinearPackedParams())._c
-    conv_packed_params = torch.jit.script(ConvPackedParams())._c
-
 def _check_is_script_module(model):
     if not isinstance(model, torch.jit.ScriptModule):
         raise ValueError('input must be a script module, got: ' + str(type(model)))
@@ -77,14 +17,14 @@
         activation=torch.jit.script(qconfig.activation())._c,
         weight=torch.jit.script(qconfig.weight())._c)
 
-def get_scripted_qconfig_dict(qconfig_dict):
+def script_qconfig_dict(qconfig_dict):
     return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
 
 def _prepare_script(model, qconfig_dict, is_dynamic):
     _check_is_script_module(model)
     if any(map(lambda x : not isinstance(x, str), qconfig_dict.keys())):
         raise ValueError('qconfig_dict should contain names(str) as keys.')
-    scripted_qconfig_dict = get_scripted_qconfig_dict(qconfig_dict)
+    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
     return wrap_cpp_module(torch._C._jit_pass_insert_observers(model._c,
                                                                'forward',
                                                                scripted_qconfig_dict,