Revert "[Quant] [PT2] Enable batchnorm in _move_exported_model_to_eval (#114547)"
This reverts commit bab054063c7fd6c4b3b8d55a932f2e7fa0a057bb.
Reverted https://github.com/pytorch/pytorch/pull/114547 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114547#issuecomment-1836612143))
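For context, the change being reverted taught move_exported_model_to_eval
to rewrite training-mode batchnorm calls into their eval-mode form via FX
subgraph rewriting (see the _replace_batchnorm_for_eval hunk below). As a
rough illustration of that rewriting idea, here is a minimal sketch using
only the public torch.fx API; it is not the reverted implementation, and
the module M is illustrative:

    import torch
    import torch.nn.functional as F
    from torch.fx import symbolic_trace
    from torch.fx.subgraph_rewriter import replace_pattern

    def bn_train(x, weight, bias, running_mean, running_var):
        # Training-mode batch norm: normalizes with batch statistics
        # and updates the running stats.
        return F.batch_norm(
            x, running_mean, running_var, weight, bias, training=True
        )

    def bn_eval(x, weight, bias, running_mean, running_var):
        # Eval-mode batch norm: normalizes with the frozen running stats.
        return F.batch_norm(
            x, running_mean, running_var, weight, bias, training=False
        )

    class M(torch.nn.Module):
        def forward(self, x, weight, bias, running_mean, running_var):
            return F.batch_norm(
                x, running_mean, running_var, weight, bias, training=True
            )

    gm = symbolic_trace(M())
    # Rewrite every training-mode batch_norm call to its eval-mode form.
    replace_pattern(gm, bn_train, bn_eval)
    gm.recompile()
    print(gm.code)  # batch_norm is now called with training=False

The reverted helper instead built aten-level match/replacement graphs with
get_aten_graph_module and applied replace_pattern_with_filters, which is
why it operated on graphs captured by capture_pre_autograd_graph.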
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 757da17..9d0ec86 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -1759,43 +1759,6 @@
check_dynamic=True,
)
- @skipIfNoDynamoSupport
- @skipIfNoONEDNN
- @skipIfRocm
- def test_qat_bn_conv2d(self):
- r"""
- This test case quantizes a BatchNorm2d -> Conv2d -> BatchNorm2d pattern with the QAT flow.
- """
-
- class M(torch.nn.Module):
- def __init__(
- self,
- ):
- super().__init__()
- self.conv = torch.nn.Conv2d(3, 3, 3)
- self.bn1 = torch.nn.BatchNorm2d(3)
- self.bn2 = torch.nn.BatchNorm2d(3)
-
- def forward(self, x):
- x = self.conv(self.bn1(x))
- return self.bn2(x)
-
- mod = M().train()
- v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
-
- def matcher_check_fn():
- self.assertEqual(
- counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
- )
-
- self._test_common(
- mod,
- (v,),
- check_quantization=True,
- is_qat=True,
- matcher_check_fn=matcher_check_fn,
- )
-
if __name__ == "__main__":
if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 674e1ed..5e3a12c 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1648,42 +1648,6 @@
self._test_move_exported_model_to_eval_dropout(inplace=False)
self._test_move_exported_model_to_eval_dropout(inplace=True)
- def test_bn_move_exported_model_to_eval(self):
- class M(torch.nn.Module):
- def __init__(
- self,
- ):
- super().__init__()
- self.bn = torch.nn.BatchNorm2d(3)
- self.conv = torch.nn.Conv2d(3, 3, 3)
-
- def forward(self, x):
- return self.conv(self.bn(x))
-
- m = M().train()
- example_inputs = (
- torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1),
- )
-
- m = capture_pre_autograd_graph(m, example_inputs)
-
- # Assert that bn op exists and is in train mode
- batch_norm_node = None
- for n in m.graph.nodes:
- if n.target == torch.ops.aten._native_batch_norm_legit.default:
- batch_norm_node = n
- break
- self.assertTrue(batch_norm_node is not None)
- self.assertTrue(batch_norm_node.args[5])
-
- # Do the subgraph rewriting
- torch.ao.quantization.move_exported_model_to_eval(m)
-
- # Assert that bn op is now in eval mode
- targets = [n.target for n in m.graph.nodes]
- self.assertTrue(torch.ops.aten._native_batch_norm_legit.default not in targets)
- self.assertTrue(torch.ops.aten._native_batch_norm_legit_no_training.default in targets)
-
def test_disallow_eval_train(self):
m = TestHelperModules.ConvWithBNRelu(relu=True)
example_inputs = (torch.rand(3, 3, 5, 5),)
diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py
index 996824c..de6e3d6 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_qat.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -159,8 +159,8 @@
self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
if verify_convert:
- model_pt2e = convert_pt2e(model_pt2e)
torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
+ model_pt2e = convert_pt2e(model_pt2e)
quant_result_pt2e = model_pt2e(*example_inputs)
model_fx.eval()
model_fx = _convert_to_reference_decomposed_fx(
diff --git a/torch/ao/quantization/pt2e/eval_utils.py b/torch/ao/quantization/pt2e/eval_utils.py
index 7699e61..c4874ed 100644
--- a/torch/ao/quantization/pt2e/eval_utils.py
+++ b/torch/ao/quantization/pt2e/eval_utils.py
@@ -45,68 +45,14 @@
m.recompile()
-def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
- # TODO(Leslie): This function does not yet support custom momentum and eps values.
- # Enable this support in future updates.
-
- # Avoid circular dependencies
- from .utils import get_aten_graph_module
-
- # Needed to ensure subgraph matches are self-contained
- m.graph.eliminate_dead_code()
- m.recompile()
-
- def bn_train(
- x: torch.Tensor,
- bn_weight: torch.Tensor,
- bn_bias: torch.Tensor,
- bn_running_mean: torch.Tensor,
- bn_running_var: torch.Tensor,
- ):
- return F.batch_norm(
- x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
- )
-
- def bn_eval(
- x: torch.Tensor,
- bn_weight: torch.Tensor,
- bn_bias: torch.Tensor,
- bn_running_mean: torch.Tensor,
- bn_running_var: torch.Tensor,
- ):
- return F.batch_norm(
- x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
- )
-
- example_inputs = (
- torch.randn(1, 1, 3, 3), # x
- torch.randn(1), # bn_weight
- torch.randn(1), # bn_bias
- torch.randn(1), # bn_running_mean
- torch.randn(1), # bn_running_var
- )
- match_pattern = get_aten_graph_module(bn_train, example_inputs)
- replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
- from torch.fx.subgraph_rewriter import replace_pattern_with_filters
-
- replace_pattern_with_filters(
- m,
- match_pattern,
- replacement_pattern,
- match_filters=[],
- ignore_literals=True,
- )
- m.recompile()
-
-
# TODO: also support move_exported_model_to_train
+# TODO: also support standalone batchnorm
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
"""
Move an exported GraphModule to eval mode.
- This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
+ This is equivalent to model.eval() but only for certain special ops like dropout.
QAT users should call this before performing inference on the model.
"""
_replace_dropout_for_eval(model)
- _replace_batchnorm_for_eval(model)
return model