[quant][graphmode][refactor] Remove unused code in quantization.cpp (#37974)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/37974

Removes code in torch/csrc/jit/passes/quantization.cpp that is no longer used by graph-mode quantization: the FoldQuantizeCallIntoBuffer and FoldPrepackedWeightIntoModule passes (together with the PatternsAndModules struct, the toTwoElementIntList helper, and the unimplemented FoldQuantNodesIntoInputsOutputs stub), their declarations in quantization.h, and their Python bindings (_jit_pass_fold_quantize, _jit_pass_fold_prepack, _jit_pass_fold_quant_inputs). The ConvPackedParams wrapper module in torch/quantization/_quantize_script.py and the tests that exercised these passes (test_fold_quantize, test_fold_prepack, test_fold_quantize_freeze) are removed as well, and get_scripted_qconfig_dict is renamed to script_qconfig_dict.
Differential Revision: D21468498
Pulled By: jerryzh168
fbshipit-source-id: 96f34db9f98474ec8e5d33e9b7c406b1637f5de8
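
For context, the supported graph-mode quantization flow that remains after this cleanup goes through prepare_script/convert_script, which test_quantize_script.py still imports. A minimal usage sketch; the exact prepare_script/convert_script argument lists are an assumption here, not taken from this diff:

    import torch
    from torch.quantization import get_default_qconfig
    from torch.quantization._quantize_script import prepare_script, convert_script

    class M(torch.nn.Module):
        def __init__(self):
            super(M, self).__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3)

        def forward(self, x):
            return self.conv(x)

    # Script an eval-mode model, insert observers, calibrate, then convert.
    m = torch.jit.script(M().eval())
    qconfig_dict = {'': get_default_qconfig('fbgemm')}
    m = prepare_script(m, qconfig_dict)   # inserts observer modules
    m(torch.randn(1, 3, 7, 7))            # calibration run
    m = convert_script(m)                 # rewrites observed ops to quantized ops
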
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 783bcfa..2db5382 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -12,30 +12,6 @@
"instead.")
class TestFreezing(JitTestCase):
- def test_fold_quantize_freeze(self):
- class M(nn.Module):
- def __init__(self):
- super(M, self).__init__()
- self.weight = nn.Parameter(torch.tensor([2], dtype=torch.float))
-
- def forward(self, x):
- return torch.quantize_per_tensor(self.weight, 2.0, 0, torch.quint8)
-
- m = torch.jit.script(M())
- m.eval()
- torch._C._jit_pass_fold_quantize(m._c, 'forward')
- m._c = torch._C._freeze_module(m._c)
- self.assertFalse(m._c.hasattr('_quantized_weight'))
- FileCheck().check_not('GetAttr[name=') \
- .run(m._c._get_method('forward').graph)
- buffer = io.BytesIO()
- torch.jit.save(m, buffer)
- buffer.seek(0)
- m_l = torch.jit.load(buffer)
- self.assertFalse(m_l._c.hasattr('_quantized_weight'))
- FileCheck().check_not('GetAttr[name=') \
- .run(m_l._c._get_method('forward').graph)
-
def test_freeze_module(self):
class M(nn.Module):
def __init__(self):
diff --git a/test/quantization/test_quantize_script.py b/test/quantization/test_quantize_script.py
index d0507b9..b8e612c 100644
--- a/test/quantization/test_quantize_script.py
+++ b/test/quantization/test_quantize_script.py
@@ -11,15 +11,11 @@
from torch.quantization import default_dynamic_qconfig
from torch.quantization import QConfigDynamic
from torch.quantization import default_observer
-from torch.quantization import default_weight_observer
from torch.quantization import default_per_channel_weight_observer
from torch.quantization import default_qconfig
from torch.quantization import get_default_qconfig
-from torch.quantization import quantize
# torch.quantization._quantize_script
-from torch.nn.quantized.modules.linear import LinearPackedParams
-from torch.quantization._quantize_script import ConvPackedParams
from torch.quantization._quantize_script import script_qconfig
from torch.quantization._quantize_script import prepare_script
from torch.quantization._quantize_script import convert_script
@@ -28,8 +24,6 @@
from torch.quantization._quantize_script import quantize_dynamic_script
# Testing utils
-from torch.testing._internal.common_quantization import SingleLayerLinearModel, AnnotatedSingleLayerLinearModel
-from torch.testing._internal.common_quantization import ConvModel, AnnotatedConvModel
from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
from torch.testing import FileCheck
@@ -793,105 +787,6 @@
.check("return") \
.run(conv._get_method('_conv_forward').graph)
- def test_fold_quantize(self):
- class M(torch.nn.Module):
- def __init__(self):
- super(M, self).__init__()
- self.weight = torch.nn.Parameter(torch.tensor([2], dtype=torch.float))
-
- def forward(self, x):
- return torch.quantize_per_tensor(self.weight, 2.0, 0, torch.quint8)
-
- m = torch.jit.script(M())
- torch._C._jit_pass_fold_quantize(m._c, 'forward')
- self.assertTrue(m._c.hasattr('_quantized_weight'))
- FileCheck().check_not('GetAttr[name="weight"]') \
- .check('GetAttr[name="_quantized_weight"]') \
- .run(m._c._get_method('forward').graph)
-
- @unittest.skipUnless(
- 'fbgemm' in torch.backends.quantized.supported_engines,
- " Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
- " with instruction set support avx2 or newer.",
- )
- @unittest.skip("Skip for now since we changed scale/zero_point to attributes."
- "We'll enable this in a separate PR")
- def test_fold_prepack(self):
- def copy_weights(name, m, ref_m):
- if name == 'linear':
- m.fc1.weight = torch.nn.Parameter(ref_m.fc1.module.weight.detach())
- m.fc1.bias = torch.nn.Parameter(ref_m.fc1.module.bias.detach())
- else:
- m.conv.weight = torch.nn.Parameter(ref_m.conv.weight.detach())
-
- for is_per_channel in [True, False]:
- for name, M, ref_M, data in [
- ('linear',
- SingleLayerLinearModel,
- AnnotatedSingleLayerLinearModel,
- torch.randn((5, 5), dtype=torch.float)),
- ('conv',
- ConvModel,
- AnnotatedConvModel,
- torch.randn((1, 3, 7, 7), dtype=torch.float))]:
- qconfig = QConfig(
- activation=default_observer,
- weight=default_per_channel_weight_observer if is_per_channel else default_weight_observer)
- # eager mode
- ref_m = ref_M()
- m = M()
- copy_weights(name, m, ref_m)
- ref_m.qconfig = qconfig
- ref_m = quantize(ref_m, _test_only_eval_fn, [(data, torch.randint(0, 1, (5,), dtype=torch.long))])
- ref_res = ref_m(data)
- # script mode
- m = torch.jit.script(m)
- qconfig_dict = {
- '': script_qconfig(qconfig)
- }
- m = wrap_cpp_module(torch._C._jit_pass_insert_observers(m._c, 'forward', qconfig_dict, False))
- get_forward(m._c)(data)
- m = wrap_cpp_module(torch._C._jit_pass_insert_quant_dequant(m._c, 'forward', False))
- torch._C._jit_pass_insert_prepack_unpack(m._c)
- linear_packed_params = torch.jit.script(LinearPackedParams())._c
- conv_packed_params = torch.jit.script(ConvPackedParams())._c
- torch._C._jit_pass_fold_prepack(m._c,
- linear_packed_params,
- conv_packed_params)
- res = get_forward(m._c)(data)
- # check result
- self.assertEqual(res, ref_res)
-
- # check attributes
- # construct a RecursiveScriptModule
- m = wrap_cpp_module(m._c)
- mod_to_inspect = m.fc1 if name == 'linear' else m.conv
- packed_module_list = [x for x, _ in mod_to_inspect._modules._c.items()
- if x.startswith('_' + name + '_packed_params_module')]
- assert len(packed_module_list) == 1, \
- 'Expected to have one packed_params_module'
- packed_module_name = packed_module_list[0]
- # check values
- w, _ = mod_to_inspect._c.getattr(packed_module_name)._get_method('_weight_bias')()
- original_w = mod_to_inspect.weight
- if is_per_channel:
- ref_w = torch.quantize_per_channel(original_w,
- w.q_per_channel_scales(),
- w.q_per_channel_zero_points(),
- w.q_per_channel_axis(),
- w.dtype).dequantize()
- else:
- ref_w = torch.quantize_per_tensor(original_w, w.q_scale(), w.q_zero_point(), w.dtype).dequantize()
- self.assertEqual(ref_w, w.dequantize())
-
- # test serialization
- buffer = io.BytesIO()
- torch.jit.save(m, buffer)
- buffer.seek(0)
- loaded_mod = torch.jit.load(buffer)
- loaded_res = loaded_mod(data)
- self.assertEqual(ref_res, loaded_res)
-
def test_dedup_module_uses(self):
class M(torch.nn.Module):
def __init__(self):
diff --git a/torch/csrc/jit/passes/quantization.cpp b/torch/csrc/jit/passes/quantization.cpp
index 6f125e9..7bc1424 100644
--- a/torch/csrc/jit/passes/quantization.cpp
+++ b/torch/csrc/jit/passes/quantization.cpp
@@ -30,7 +30,6 @@
using ModuleMethodVector = std::vector<std::pair<Module, std::string>>;
using NameModuleVector = std::vector<std::pair<std::string, Module>>;
using graph_rewrite_helper::getFuncName;
-using graph_rewrite_helper::getIValue;
using graph_rewrite_helper::getValue;
using graph_rewrite_helper::PatternInfo;
using graph_rewrite_helper::replaceConvolutionWithAtenConv;
@@ -40,13 +39,6 @@
// _scalar_type and _axis(for per channel quantization)
using QParamVector = std::vector<std::pair<std::string, IValue>>;
-struct PatternsAndModules {
- bool is_conv;
- bool is_per_channel;
- const PatternInfo& pattern;
- Module packed_params_module;
-};
-
std::vector<std::string> _static_quantizable_call_funcs = {
"conv2d",
"linear",
@@ -1187,7 +1179,6 @@
return nodeQuantizable(use.user, is_dynamic);
}
-// TODO: remove this as a class method
bool InsertObserversHelper::valueNeedsToBeQuantized(Value* v) {
if (isBiasOfConvOrLinear(v) ||
!(v->type()->isSubtypeOf(TensorType::get()) ||
@@ -2230,26 +2221,6 @@
}
}
-c10::optional<IValue> toTwoElementIntList(Value* v) {
- auto* n = v->node();
- if (n->kind() == prim::Constant) {
- auto iv = toIValue(v);
- if (iv && iv.value().isIntList() && iv.value().toIntList().size() == 2) {
- return iv;
- }
- }
-
- if (n->kind() == prim::ListConstruct && n->inputs().size() == 2) {
- auto e0 = toIValue(n->inputs()[0]);
- auto e1 = toIValue(n->inputs()[1]);
- if (!e0 || !e1 || !e0.value().isInt() || !e1.value().isInt()) {
- return c10::nullopt;
- }
- return IValue(c10::List<int64_t>({e0.value().toInt(), e1.value().toInt()}));
- }
- return c10::nullopt;
-}
-
// A helper class to make uses of module unique
class ModuleUseDeduper {
public:
@@ -2851,10 +2822,6 @@
return module;
}
-void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph) {
- throw std::runtime_error("Pass not implemented yet!");
-}
-
void SwapFunctionalLinear(Module& module) {
for (auto& method : module.get_methods()) {
std::shared_ptr<Graph> g = method.graph();
@@ -3004,62 +2971,6 @@
return m;
}
-void FoldQuantizeCallIntoBuffer(
- Module& module,
- const std::string& method_name) {
- const PatternInfo& pattern = PatternInfo::parse_from_str(R"(
-graph(%self, %scale, %zero_point, %dtype):
- %weight = prim::GetAttr[name="weight"](%self)
- %weight_quant = aten::quantize_per_tensor(%weight, %scale, %zero_point, %dtype)
- return (%weight_quant) )");
- const Graph& pattern_graph = *pattern.pattern_graph;
- const auto& vmap = pattern.vmap;
-
- auto method = module.get_method(method_name);
- auto graph = method.graph();
- const auto& matches = findPatternMatches(pattern_graph, *graph);
- // Extra filter on scale/zero_point/dtype to make sure they are Constant
- auto filter = [](const Match& match,
- const std::unordered_map<std::string, Value*>& vmap) {
- const auto& match_vmap = match.values_map;
- auto scale_node = match_vmap.at(vmap.at("scale"))->node();
- auto zero_point_node = match_vmap.at(vmap.at("zero_point"))->node();
- auto dtype_node = match_vmap.at(vmap.at("dtype"))->node();
- return scale_node->kind() == prim::Constant &&
- zero_point_node->kind() == prim::Constant &&
- dtype_node->kind() == prim::Constant;
- };
- std::unordered_set<Node*> nodes_to_delete;
- for (const auto& match : matches) {
- if (!filter(match, vmap)) {
- continue;
- }
- auto match_vmap = match.values_map;
- auto float_weight = module.attr("weight").toTensor().data();
- auto scale = toIValue(match_vmap.at(vmap.at("scale"))).value().toDouble();
- auto zero_point =
- toIValue(match_vmap.at(vmap.at("zero_point"))).value().toInt();
- auto dtype =
- toIValue(match_vmap.at(vmap.at("dtype"))).value().toScalarType();
- module.register_buffer(
- "_quantized_weight",
- at::quantize_per_tensor(float_weight, scale, zero_point, dtype));
-
- // Replace the GetAttr[weight]->quantize_per_tensor sequence
- // with a simple GetAttr[_quantized_weight] node.
- Value* orig_weight = match_vmap.at(vmap.at("weight"));
- Value* orig_weight_quant = match_vmap.at(vmap.at("weight_quant"));
-
- orig_weight->node()->s_(attr::name, "_quantized_weight");
- orig_weight_quant->replaceAllUsesWith(orig_weight);
- nodes_to_delete.insert(orig_weight_quant->node());
- }
-
- for (Node* n : nodes_to_delete) {
- n->destroy();
- }
-}
-
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
insertPrepackUnpackForLinear(graph);
insertPrepackUnpackForConv(graph);
@@ -3075,192 +2986,6 @@
}
}
-struct FoldPrepackedWeightIntoModuleHelper {
- void run(
- Module& module,
- const std::string& method_name,
- const Module& linear_params_module,
- const Module& conv_params_module) {
- auto method = module.get_method(method_name);
- auto graph = method.graph();
- GRAPH_DUMP("Before FoldPrepackWeightIntoModule: ", graph);
-
- // (is_conv, is_per_channel, pattern, packed_params_module)
- std::vector<PatternsAndModules> pattern_and_modules = {
- {false, false, linear_prepack_per_tensor, linear_params_module},
- {false, true, linear_prepack_per_channel, linear_params_module},
- {true, false, conv2d_prepack, conv_params_module},
- {true, true, conv2d_prepack_per_channel, conv_params_module}};
- for (const auto& pm : pattern_and_modules) {
- const Graph& pattern_graph = *pm.pattern.pattern_graph;
- const auto& vmap = pm.pattern.vmap;
- const auto& matches = findPatternMatches(pattern_graph, *graph);
- TORCH_INTERNAL_ASSERT(
- matches.size() <= 1, "We only support at most one match right now");
- for (const auto& match : matches) {
- const auto& match_vmap = match.values_map;
- auto w_dtype_opt = getIValue("w_dtype", match_vmap, vmap);
- auto w_scale_opt = getIValue("w_scale", match_vmap, vmap);
- auto w_zero_point_opt = getIValue("w_zero_point", match_vmap, vmap);
- if (!w_dtype_opt || !w_scale_opt || !w_zero_point_opt) {
- GRAPH_DEBUG(
- "dtype, scale or zero_point for weight(",
- getValue("w_dtype", match_vmap, vmap)->debugName(),
- ", ",
- getValue("w_scale", match_vmap, vmap)->debugName(),
- ", ",
- getValue("w_zero_point", match_vmap, vmap)->debugName(),
- ") is not constant, skipping the match.");
- continue;
- }
- auto w_dtype = w_dtype_opt.value().toScalarType();
- auto w = module.attr("weight").toTensor().data();
- at::Tensor w_quant;
- if (pm.is_per_channel) {
- auto w_axis_opt = getIValue("w_axis", match_vmap, vmap);
- if (!w_axis_opt) {
- GRAPH_DEBUG(
- "axis for weight ",
- getValue("w_axis", match_vmap, vmap)->debugName(),
- " is non-constant, skipping the match");
- continue;
- }
- auto w_scale = w_scale_opt.value().toTensor().to(at::kFloat);
- auto w_zero_point = w_zero_point_opt.value().toTensor().to(at::kInt);
- int w_axis = w_axis_opt.value().toInt();
- TORCH_CHECK(
- w_scale.sizes() == w_zero_point.sizes(),
- "scale and zero_point must have the same size");
- w_quant = at::quantize_per_channel(
- w, w_scale, w_zero_point, w_axis, w_dtype);
- } else {
- auto w_scale = w_scale_opt.value().toDouble();
- auto w_zero_point = w_zero_point_opt.value().toInt();
- w_quant = at::quantize_per_tensor(w, w_scale, w_zero_point, w_dtype);
- }
- c10::optional<at::Tensor> b = c10::nullopt;
- if (hastensor(module, "bias")) {
- b = module.attr("bias").toTensor().data();
- }
- Module wrapper_module = pm.packed_params_module.clone();
- auto set_weight_bias = wrapper_module.get_method("set_weight_bias");
- std::string module_name_prefix;
- if (pm.is_conv) {
- module_name_prefix = "_conv_packed_params_module_for_";
- auto stride_opt =
- toTwoElementIntList(getValue("stride", match_vmap, vmap));
- auto padding_opt =
- toTwoElementIntList(getValue("padding", match_vmap, vmap));
- auto dilation_opt =
- toTwoElementIntList(getValue("dilation", match_vmap, vmap));
- auto groups_opt = getIValue("groups", match_vmap, vmap);
- auto set_conv_params = wrapper_module.get_method("set_conv_params");
- if (!stride_opt || !padding_opt || !dilation_opt) {
- GRAPH_DEBUG(
- "Failed to extract two element IntList for stride/padding/dilation, (",
- getValue("stride", match_vmap, vmap)->debugName(),
- ", ",
- getValue("padding", match_vmap, vmap)->debugName(),
- ", ",
- getValue("dilation", match_vmap, vmap)->debugName(),
- ") skipping the match");
- continue;
- }
- set_conv_params(std::vector<IValue>{stride_opt.value(),
- padding_opt.value(),
- dilation_opt.value(),
- groups_opt.value()});
- } else {
- module_name_prefix = "_linear_packed_params_module_for_";
- }
- set_weight_bias(std::vector<IValue>{IValue(w_quant), IValue(b)});
- auto w_quant_val = getValue("w_quant", match_vmap, vmap);
- // unique name for the module based on %w_quant
- int uid = 0;
- auto module_name = module_name_prefix + c10::to_string(uid++);
- while (module.hasattr(module_name)) {
- module_name_prefix + c10::to_string(uid++);
- }
- GRAPH_UPDATE("Adding new module: ", module_name);
- module.register_module(module_name, wrapper_module);
-
- // Add GetAttr of the packed module
- auto packed_params_val = getValue("packed_params", match_vmap, vmap);
- WithInsertPoint ins(packed_params_val->node());
- // wrapper_module =
- // self.{_conv,_linear}_packed_params_module_for_{unique_id}
- Value* packed_params_module =
- graph->insertGetAttr(graph->inputs()[0], module_name)
- ->setType(wrapper_module.type());
- GRAPH_UPDATE("Adding GetAttr node for the wrapper module");
-
- // packed_params = wrapper_module._packed_params
- Value* packed_params_from_attr =
- graph->insertGetAttr(packed_params_module, "_packed_params");
- GRAPH_UPDATE(
- "Adding GetAttr node for _packed_params: ",
- packed_params_from_attr->debugName());
- packed_params_val->replaceAllUsesWith(packed_params_from_attr);
-
- // Delete nodes
- std::vector<Node*> nodes_to_delete = {w_quant_val->node(),
- packed_params_val->node()};
- for (auto n : nodes_to_delete) {
- n->removeAllInputs();
- }
- for (auto n : nodes_to_delete) {
- GRAPH_UPDATE("Deleting node: ", n);
- n->destroy();
- }
- }
- }
- }
-
- void run(
- Module& module,
- const Module& linear_params_module,
- const Module& conv_params_module) {
- for (auto& method : module.get_methods()) {
- run(module, method.name(), linear_params_module, conv_params_module);
- }
- for (Module m : module.children()) {
- run(m, linear_params_module, conv_params_module);
- }
- }
-
- const PatternInfo linear_prepack_per_tensor = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype):
- %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
- %packed_params = quantized::linear_prepack(%w_quant, %b)
- return (%packed_params) )");
-
- const PatternInfo linear_prepack_per_channel = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype):
- %w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
- %packed_params = quantized::linear_prepack(%w_quant, %b)
- return (%packed_params) )");
-
- const PatternInfo conv2d_prepack = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_dtype, %stride, %padding, %dilation, %groups):
- %w_quant = aten::quantize_per_tensor(%w, %w_scale, %w_zero_point, %w_dtype)
- %packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
- return (%packed_params) )");
-
- const PatternInfo conv2d_prepack_per_channel = PatternInfo::parse_from_str(R"(
-graph(%a_dequant, %w, %b, %w_scale, %w_zero_point, %w_axis, %w_dtype, %stride, %padding, %dilation, %groups):
- %w_quant = aten::quantize_per_channel(%w, %w_scale, %w_zero_point, %w_axis, %w_dtype)
- %packed_params = quantized::conv2d_prepack(%w_quant, %b, %stride, %padding, %dilation, %groups)
- return (%packed_params) )");
-};
-
-void FoldPrepackedWeightIntoModule(
- Module& module,
- const Module& linear_params_module,
- const Module& conv_params_module) {
- FoldPrepackedWeightIntoModuleHelper h;
- h.run(module, linear_params_module, conv_params_module);
-}
-
void DedupModuleUses(Module& module) {
ModuleUseDeduper d(module);
d.dedup();
diff --git a/torch/csrc/jit/passes/quantization.h b/torch/csrc/jit/passes/quantization.h
index 9a2381a..6b2732d 100644
--- a/torch/csrc/jit/passes/quantization.h
+++ b/torch/csrc/jit/passes/quantization.h
@@ -40,13 +40,6 @@
using QConfigTypePtrMap =
std::unordered_map<c10::optional<QConfig>, TypePtr, OptionalQConfigHash>;
-/** \brief Quantize model's inputs and outputs.
- *
- * This pass folds quant/dequant ops into the input/output tensors, essentially
- * quantizing these tensors. It's done to reduce model's memory footprint.
- */
-TORCH_API void FoldQuantNodesIntoInputsOutputs(std::shared_ptr<Graph>& graph);
-
/** \brief Insert observer module and observer function call for
* the Tensors that needs to be observed.
*
@@ -144,19 +137,6 @@
*/
TORCH_API Module FoldConvBatchNorm2d(const Module& module);
-/** \brief Fold quantize function call into module
- *
- * For the graph of the specified method of module, if we find a
- * quantize_per_tensor call on an attribute("weight") of the module, we'll
- * quantize the attribute directly and register a new buffer "_quantized_weight"
- * on the module and remove the quantize_per_tensor call and replace the use of
- * the quantized weight with
- * "_quantized_weight".
- */
-TORCH_API void FoldQuantizeCallIntoBuffer(
- Module& module,
- const std::string& method_name);
-
/** \brief Insert prepack and unpack function in graph
* We want add pack/unpack functions for quantized weight because later we want
* to fold the packed weight as an attribute of the module, in order to reduce
@@ -176,25 +156,6 @@
*/
TORCH_API void InsertPrepackUnpack(Module& module);
-/** \brief Fold prepack function call into module
- *
- * For the graph of the specified method, if we find a
- * `quantized::linear_prepack` call, we'll clone the wrapper module and set the
- * weight and bias of the module and add the wrapper module as a child to the
- * input module. Folding is recursively applied to all methods of all child
- * modules of the input module
- *
- * Wrapper module is used to overwrite serialization for packed
- * weight and bias since they are not recognized by JIT, this
- * is a workaround, a long term solution would be to support serialization of
- * packed weight and bias using custom types
- *
- */
-TORCH_API void FoldPrepackedWeightIntoModule(
- Module& module,
- const Module& linear_params_module,
- const Module& conv_params_module);
-
/** Recursively deduplicate multiple uses of the same module by
* creating an instance clone for each use of the module, which means
* the type will be the same as before and all the attributes will be
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 6bfe4ab..58b0fab 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -217,12 +217,6 @@
[](Module& module) { return freeze_module(module); },
py::arg("module"))
.def("_jit_pass_fuse_linear", &FuseLinear)
- .def(
- "_jit_pass_fold_quantize",
- [](Module& module, const std::string& method_name) {
- FoldQuantizeCallIntoBuffer(module, method_name);
- })
- .def("_jit_pass_fold_prepack", &FoldPrepackedWeightIntoModule)
.def("_jit_pass_dedup_module_uses", &DedupModuleUses)
.def("_jit_pass_replicate_dequantize", &ReplicateDeQuant)
.def("_jit_pass_swap_dequantize", &PropagateQuantizationOps)
@@ -261,11 +255,6 @@
subgraph_rewriter.runOnGraph(g);
})
.def(
- "_jit_pass_fold_quant_inputs",
- [](std::shared_ptr<Graph>& g) {
- return FoldQuantNodesIntoInputsOutputs(g);
- })
- .def(
"_jit_pass_remove_inplace_ops",
[](std::shared_ptr<Graph> g) { return RemoveInplaceOps(g); })
.def("_jit_pass_constant_pooling", ConstantPooling)
diff --git a/torch/quantization/_quantize_script.py b/torch/quantization/_quantize_script.py
index 4761864..467081ad 100644
--- a/torch/quantization/_quantize_script.py
+++ b/torch/quantization/_quantize_script.py
@@ -1,69 +1,9 @@
from __future__ import absolute_import, division, print_function, unicode_literals
-from typing import List, Optional
-
import torch
from .qconfig import QConfig
from torch.jit._recursive import wrap_cpp_module
-class ConvPackedParams(torch.nn.Module):
- def __init__(self):
- super(ConvPackedParams, self).__init__()
- wq = torch._empty_affine_quantized([1, 1, 1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
- self.stride = [1, 1]
- self.padding = [0, 0]
- self.dilation = [1, 1]
- self.groups = 1
- self.set_weight_bias(wq, None)
-
- @torch.jit.export
- def set_conv_params(self, stride, padding, dilation, groups):
- # type: (List[int], List[int], List[int], int) -> None
- self.stride = stride
- self.padding = padding
- self.dilation = dilation
- self.groups = groups
-
- @torch.jit.export
- def set_weight_bias(self, weight, bias):
- # type: (torch.Tensor, Optional[torch.Tensor]) -> None
- self._packed_params = torch.ops.quantized.conv2d_prepack(weight, bias, self.stride,
- self.padding, self.dilation, self.groups)
-
- @torch.jit.export
- def _weight_bias(self):
- return torch.ops.quantized.conv2d_unpack(self._packed_params)
-
- def forward(self, x):
- return x
-
- @torch.jit.export
- def __getstate__(self):
- qweight, bias = self._weight_bias()
- return (qweight,
- bias,
- self.stride,
- self.padding,
- self.dilation,
- self.groups,
- self.training)
-
- @torch.jit.export
- def __setstate__(self, state):
- self.stride = state[2]
- self.padding = state[3]
- self.dilation = state[4]
- self.groups = state[5]
- self.set_weight_bias(state[0],
- state[1])
- self.training = state[6]
-
-linear_packed_params = None
-conv_packed_params = None
-if 'fbgemm' in torch.backends.quantized.supported_engines:
- linear_packed_params = torch.jit.script(torch.nn.quantized.modules.linear.LinearPackedParams())._c
- conv_packed_params = torch.jit.script(ConvPackedParams())._c
-
def _check_is_script_module(model):
if not isinstance(model, torch.jit.ScriptModule):
raise ValueError('input must be a script module, got: ' + str(type(model)))
@@ -77,14 +17,14 @@
activation=torch.jit.script(qconfig.activation())._c,
weight=torch.jit.script(qconfig.weight())._c)
-def get_scripted_qconfig_dict(qconfig_dict):
+def script_qconfig_dict(qconfig_dict):
return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
def _prepare_script(model, qconfig_dict, is_dynamic):
_check_is_script_module(model)
if any(map(lambda x : not isinstance(x, str), qconfig_dict.keys())):
raise ValueError('qconfig_dict should contain names(str) as keys.')
- scripted_qconfig_dict = get_scripted_qconfig_dict(qconfig_dict)
+ scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
return wrap_cpp_module(torch._C._jit_pass_insert_observers(model._c,
'forward',
scripted_qconfig_dict,
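
As a reference for the rename above, a hedged sketch of how script_qconfig_dict is meant to be called; the submodule name in the dict is illustrative only:

    from torch.quantization import default_qconfig
    from torch.quantization._quantize_script import script_qconfig_dict

    # Keys are submodule names ('' is the top-level module); a None value leaves
    # that submodule unquantized, handled by the `if v else None` branch above.
    qconfig_dict = {'': default_qconfig, 'sub_to_skip': None}
    scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
    # _prepare_script then hands this dict of scripted qconfigs to
    # torch._C._jit_pass_insert_observers, as in the hunk above.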