Revert "[Quant] [PT2] Enable batchnorm in _move_exported_model_to_eval (#114547)"

This reverts commit bab054063c7fd6c4b3b8d55a932f2e7fa0a057bb.

Reverted https://github.com/pytorch/pytorch/pull/114547 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/114547#issuecomment-1836612143))
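
For context, the reverted change taught `torch.ao.quantization.move_exported_model_to_eval` to rewrite batchnorm from train to eval semantics in an exported graph; with this revert, only dropout is rewritten and batchnorm is left in train mode. A minimal sketch of the post-revert behavior (op names taken from the removed test below; the `capture_pre_autograd_graph` import path is an assumption):

```python
import torch
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        return self.bn(x)

m = capture_pre_autograd_graph(M().train(), (torch.randn(1, 3, 8, 8),))
torch.ao.quantization.move_exported_model_to_eval(m)

# With this revert applied, batchnorm is untouched: the graph still contains
# the train-mode op instead of _native_batch_norm_legit_no_training.
assert any(
    n.target == torch.ops.aten._native_batch_norm_legit.default
    for n in m.graph.nodes
)
```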
diff --git a/test/inductor/test_mkldnn_pattern_matcher.py b/test/inductor/test_mkldnn_pattern_matcher.py
index 757da17..9d0ec86 100644
--- a/test/inductor/test_mkldnn_pattern_matcher.py
+++ b/test/inductor/test_mkldnn_pattern_matcher.py
@@ -1759,43 +1759,6 @@
             check_dynamic=True,
         )
 
-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    @skipIfRocm
-    def test_qat_bn_conv2d(self):
-        r"""
-        This testcase will quantize a single BN Conv2d module with qat flow.
-        """
-
-        class M(torch.nn.Module):
-            def __init__(
-                self,
-            ):
-                super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-                self.bn1 = torch.nn.BatchNorm2d(3)
-                self.bn2 = torch.nn.BatchNorm2d(3)
-
-            def forward(self, x):
-                x = self.conv(self.bn1(x))
-                return self.bn2(x)
-
-        mod = M().train()
-        v = torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1)
-
-        def matcher_check_fn():
-            self.assertEqual(
-                counters["inductor"]["qconv2d_weight_prepack_matcher_count"], 1
-            )
-
-        self._test_common(
-            mod,
-            (v,),
-            check_quantization=True,
-            is_qat=True,
-            matcher_check_fn=matcher_check_fn,
-        )
-
 
 if __name__ == "__main__":
     if IS_LINUX and HAS_CPU and torch.backends.mkldnn.is_available():
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index 674e1ed..5e3a12c 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1648,42 +1648,6 @@
         self._test_move_exported_model_to_eval_dropout(inplace=False)
         self._test_move_exported_model_to_eval_dropout(inplace=True)
 
-    def test_bn_move_exported_model_to_eval(self):
-        class M(torch.nn.Module):
-            def __init__(
-                self,
-            ):
-                super().__init__()
-                self.bn = torch.nn.BatchNorm2d(3)
-                self.conv = torch.nn.Conv2d(3, 3, 3)
-
-            def forward(self, x):
-                return self.conv(self.bn(x))
-
-        m = M().train()
-        example_inputs = (
-            torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1),
-        )
-
-        m = capture_pre_autograd_graph(m, example_inputs)
-
-        # Assert that bn op exists and is in train mode
-        batch_norm_node = None
-        for n in m.graph.nodes:
-            if n.target == torch.ops.aten._native_batch_norm_legit.default:
-                batch_norm_node = n
-                break
-        self.assertTrue(batch_norm_node is not None)
-        self.assertTrue(batch_norm_node.args[5])
-
-        # Do the subgraph rewriting
-        torch.ao.quantization.move_exported_model_to_eval(m)
-
-        # Assert that bn op is now in eval mode
-        targets = [n.target for n in m.graph.nodes]
-        self.assertTrue(torch.ops.aten._native_batch_norm_legit.default not in targets)
-        self.assertTrue(torch.ops.aten._native_batch_norm_legit_no_training.default in targets)
-
     def test_disallow_eval_train(self):
         m = TestHelperModules.ConvWithBNRelu(relu=True)
         example_inputs = (torch.rand(3, 3, 5, 5),)
diff --git a/test/quantization/pt2e/test_quantize_pt2e_qat.py b/test/quantization/pt2e/test_quantize_pt2e_qat.py
index 996824c..de6e3d6 100644
--- a/test/quantization/pt2e/test_quantize_pt2e_qat.py
+++ b/test/quantization/pt2e/test_quantize_pt2e_qat.py
@@ -159,8 +159,8 @@
         self.assertEqual(after_prepare_result_pt2e, after_prepare_result_fx)
 
         if verify_convert:
-            model_pt2e = convert_pt2e(model_pt2e)
             torch.ao.quantization.move_exported_model_to_eval(model_pt2e)
+            model_pt2e = convert_pt2e(model_pt2e)
             quant_result_pt2e = model_pt2e(*example_inputs)
             model_fx.eval()
             model_fx = _convert_to_reference_decomposed_fx(
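
The hunk above also restores the original ordering in the QAT convert path: `move_exported_model_to_eval` runs before `convert_pt2e` again. A self-contained sketch of that restored flow (the `XNNPACKQuantizer` setup is an assumption for illustration, not part of this diff):

```python
import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_qat_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU())
example_inputs = (torch.randn(1, 3, 8, 8),)

# Assumed quantizer setup for the sketch; any PT2E quantizer works here.
quantizer = XNNPACKQuantizer().set_global(
    get_symmetric_quantization_config(is_qat=True)
)
model = capture_pre_autograd_graph(model.train(), example_inputs)
model = prepare_qat_pt2e(model, quantizer)
model(*example_inputs)  # run at least one step to populate observers

# Restored ordering: rewrite eval-only ops (dropout) first, then lower
# the prepared model to quantize/dequantize ops.
torch.ao.quantization.move_exported_model_to_eval(model)
model = convert_pt2e(model)
quant_result = model(*example_inputs)
```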
diff --git a/torch/ao/quantization/pt2e/eval_utils.py b/torch/ao/quantization/pt2e/eval_utils.py
index 7699e61..c4874ed 100644
--- a/torch/ao/quantization/pt2e/eval_utils.py
+++ b/torch/ao/quantization/pt2e/eval_utils.py
@@ -45,68 +45,14 @@
         m.recompile()
 
 
-def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
-    # TODO(Leslie): This function still fails to support custom momentum and eps value.
-    # Enable this support in future updates.
-
-    # Avoid circular dependencies
-    from .utils import get_aten_graph_module
-
-    # Needed to ensure subgraph matches are self-contained
-    m.graph.eliminate_dead_code()
-    m.recompile()
-
-    def bn_train(
-        x: torch.Tensor,
-        bn_weight: torch.Tensor,
-        bn_bias: torch.Tensor,
-        bn_running_mean: torch.Tensor,
-        bn_running_var: torch.Tensor,
-    ):
-        return F.batch_norm(
-            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=True
-        )
-
-    def bn_eval(
-        x: torch.Tensor,
-        bn_weight: torch.Tensor,
-        bn_bias: torch.Tensor,
-        bn_running_mean: torch.Tensor,
-        bn_running_var: torch.Tensor,
-    ):
-        return F.batch_norm(
-            x, bn_running_mean, bn_running_var, bn_weight, bn_bias, training=False
-        )
-
-    example_inputs = (
-        torch.randn(1, 1, 3, 3),  # x
-        torch.randn(1),  # bn_weight
-        torch.randn(1),  # bn_bias
-        torch.randn(1),  # bn_running_mean
-        torch.randn(1),  # bn_running_var
-    )
-    match_pattern = get_aten_graph_module(bn_train, example_inputs)
-    replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
-    from torch.fx.subgraph_rewriter import replace_pattern_with_filters
-
-    replace_pattern_with_filters(
-        m,
-        match_pattern,
-        replacement_pattern,
-        match_filters=[],
-        ignore_literals=True,
-    )
-    m.recompile()
-
-
 # TODO: also support move_exported_model_to_train
+# TODO: also support standalone batchnorm
 def _move_exported_model_to_eval(model: torch.fx.GraphModule):
     """
     Move an exported GraphModule to eval mode.
 
-    This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
+    This is equivalent to model.eval() but only for certain special ops like dropout.
     QAT users should call this before performing inference on the model.
     """
     _replace_dropout_for_eval(model)
-    _replace_batchnorm_for_eval(model)
     return model
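
With batchnorm handling removed, dropout is the one op the docstring above still promises to rewrite. The actual `_replace_dropout_for_eval` body is not shown in this diff, but it plausibly mirrors the removed `_replace_batchnorm_for_eval`: trace a train-mode pattern and an eval-mode replacement with `get_aten_graph_module`, then rewrite via `replace_pattern_with_filters`. A sketch under that assumption (the dropout patterns are illustrative, not the real implementation):

```python
import torch
import torch.nn.functional as F
from torch.ao.quantization.pt2e.utils import get_aten_graph_module
from torch.fx.subgraph_rewriter import replace_pattern_with_filters

def _replace_dropout_for_eval_sketch(m: torch.fx.GraphModule):
    # Needed to ensure subgraph matches are self-contained, as in the
    # removed batchnorm rewrite.
    m.graph.eliminate_dead_code()
    m.recompile()

    def dropout_train(x: torch.Tensor):
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x: torch.Tensor):
        return F.dropout(x, p=0.5, training=False)

    example_inputs = (torch.randn(1),)
    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)

    # Swap every matched train-mode dropout for its eval-mode equivalent;
    # ignore_literals=True treats literals such as p as wildcards, so
    # dropout rates other than 0.5 still match.
    replace_pattern_with_filters(
        m,
        match_pattern,
        replacement_pattern,
        match_filters=[],
        ignore_literals=True,
    )
    m.recompile()
```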