[quant][pt2] Add `move_exported_model_to_train` (#113492)

Summary: This adds `move_exported_model_to_train`, the equivalent of
`model.train()` for exported models, analogous to the existing
`move_exported_model_to_eval`.
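
A minimal usage sketch (`M` and `example_inputs` stand in for any user
model and inputs; the capture and quantization calls are the same ones
exercised in the tests below):

    import torch
    from torch._export import capture_pre_autograd_graph

    m = M().train()
    m = capture_pre_autograd_graph(m, example_inputs)
    # ... QAT preparation / training ...
    torch.ao.quantization.move_exported_model_to_eval(m)   # dropout/BN now use eval behavior
    # ... evaluation ...
    torch.ao.quantization.move_exported_model_to_train(m)  # and back to train behavior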

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_dropout
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_dropout_inplace
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_bn

Reviewers: jerryzh168, kimishpatel

Subscribers: jerryzh168, kimishpatel, supriyar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113492
Approved by: https://github.com/jerryzh168, https://github.com/tugsbayasgalan
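
Mechanically, both APIs rewrite aten patterns in the exported graph via
torch.fx.subgraph_rewriter.replace_pattern_with_filters, matching the
train-mode pattern and replacing it with the eval-mode one (or vice
versa). An illustrative sketch of the (non-inplace) dropout pattern
pair, paraphrased from torch/ao/quantization/pt2e/eval_utils.py below:

    import torch.nn.functional as F

    def dropout_train(x):
        # matched when moving train -> eval, used as replacement in the other direction
        return F.dropout(x, p=0.5, training=True)

    def dropout_eval(x):
        # used as replacement when moving train -> eval, matched in the other direction
        return F.dropout(x, p=0.5, training=False)
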
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index eb2bc61..177c5ae 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1498,7 +1498,22 @@
             qconfig_mapping,
         )
 
-    def _test_move_exported_model_to_eval_dropout(self, inplace=False):
+    def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload):
+        """
+        Return the first node matching the specified target, throwing an exception
+        if no such node is found.
+        """
+        for n in m.graph.nodes:
+            if n.target == target:
+                return n
+        raise ValueError(f"Did not find node with target {target}")
+
+    def _test_move_exported_model_dropout(self, inplace: bool):
+        """
+        Test switching dropout behavior between train and eval modes using the
+        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
+        """
+
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -1510,71 +1525,75 @@
         example_inputs = (torch.randn(1),)
         m = M().train()
         m = capture_pre_autograd_graph(m, example_inputs)
+        if inplace:
+            target = torch.ops.aten.dropout_.default
+        else:
+            target = torch.ops.aten.dropout.default
 
         # Assert that dropout op exists and is in train mode
-        dropout_node = None
-        for n in m.graph.nodes:
-            if n.target == torch.ops.aten.native_dropout.default or n.target == torch.ops.aten.dropout_.default:
-                dropout_node = n
-                break
+        dropout_node = self._get_node(m, target)
         self.assertTrue(dropout_node is not None)
         self.assertTrue(dropout_node.args[2])
 
-        # Do the subgraph rewriting
+        # Move to eval
         torch.ao.quantization.move_exported_model_to_eval(m)
 
-        # Assert that dropout op is now replaced with a clone op
-        targets = [n.target for n in m.graph.nodes]
-        if inplace:
-            dropout_eval_node = None
-            for node in m.graph.nodes:
-                if node.target == torch.ops.aten.dropout_.default:
-                    dropout_eval_node = node
-            self.assertTrue(dropout_eval_node is not None)
-            self.assertFalse(dropout_eval_node.args[2])
-        else:
-            self.assertTrue(torch.ops.aten.clone.default in targets)
-            self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
+        # Assert that dropout op is now in eval mode
+        dropout_node = self._get_node(m, target)
+        self.assertTrue(dropout_node is not None)
+        self.assertFalse(dropout_node.args[2])  # training=False
 
-    def test_move_exported_model_to_eval(self):
-        self._test_move_exported_model_to_eval_dropout(inplace=False)
-        self._test_move_exported_model_to_eval_dropout(inplace=True)
+        # Move back to train
+        torch.ao.quantization.move_exported_model_to_train(m)
 
-    def test_bn_move_exported_model_to_eval(self):
+        # Assert that dropout op is now in train mode again
+        dropout_node = self._get_node(m, target)
+        self.assertTrue(dropout_node is not None)
+        self.assertTrue(dropout_node.args[2])  # training=True again
+
+    def test_move_exported_model_dropout(self):
+        self._test_move_exported_model_dropout(inplace=False)
+
+    def test_move_exported_model_dropout_inplace(self):
+        self._test_move_exported_model_dropout(inplace=True)
+
+    def test_move_exported_model_bn(self):
+        """
+        Test switching batch_norm behavior between train and eval modes using the
+        `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
+        """
+
         class M(torch.nn.Module):
-            def __init__(
-                self,
-            ):
+            def __init__(self):
                 super().__init__()
                 self.bn = torch.nn.BatchNorm2d(3)
-                self.conv = torch.nn.Conv2d(3, 3, 3)
 
             def forward(self, x):
-                return self.conv(self.bn(x))
+                return self.bn(x)
 
+        example_inputs = (torch.randn(1, 3, 3, 3),)
         m = M().train()
-        example_inputs = (
-            torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1),
-        )
-
         m = capture_pre_autograd_graph(m, example_inputs)
 
-        # Assert that bn op exists and is in train mode
-        batch_norm_node = None
-        for n in m.graph.nodes:
-            if n.target == torch.ops.aten._native_batch_norm_legit.default:
-                batch_norm_node = n
-                break
-        self.assertTrue(batch_norm_node is not None)
-        self.assertTrue(batch_norm_node.args[5])
+        # Assert that batch norm op exists and is in train mode
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
+        self.assertTrue(bn_node is not None)
+        self.assertTrue(bn_node.args[5])  # args[5] is the training flag
 
-        # Do the subgraph rewriting
+        # Move to eval
         torch.ao.quantization.move_exported_model_to_eval(m)
 
-        # Assert that bn op is now in eval mode
-        targets = [n.target for n in m.graph.nodes]
-        self.assertTrue(torch.ops.aten._native_batch_norm_legit.default not in targets)
-        self.assertTrue(torch.ops.aten._native_batch_norm_legit_no_training.default in targets)
+        # Assert that batch norm op is now in eval mode
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit_no_training.default)
+        self.assertTrue(bn_node is not None)
+
+        # Move to train
+        torch.ao.quantization.move_exported_model_to_train(m)
+
+        # Assert that batch norm op is now in train mode again
+        bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
+        self.assertTrue(bn_node is not None)
+        self.assertTrue(bn_node.args[5])
 
     def test_disallow_eval_train(self):
         m = TestHelperModules.ConvWithBNRelu(relu=True)
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index ddfddb5..1f3e55d 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -156,7 +156,14 @@
     else:
         constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
 
-    decomp_table = {op: op.decompose for op in FunctionalTensor.maybe_aliasing_or_mutating_ops}
+    # Do not decompose dropout for exported models, because in eval mode the dropout
+    # op disappears from the graph, which makes it difficult to switch to train mode.
+    # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
+    decomp_table = {
+        op: op.decompose
+        for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
+        if op != torch.ops.aten.dropout.default
+    }
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
         m = torch._dynamo.export(
             f,
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 04323f6..23e469b 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -13,6 +13,7 @@
 from .quantize_jit import *  # noqa: F403
 from .stubs import *  # noqa: F403
 from .pt2e.eval_utils import _move_exported_model_to_eval as move_exported_model_to_eval
+from .pt2e.eval_utils import _move_exported_model_to_train as move_exported_model_to_train
 from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle  # noqa: F401
 from typing import Union, List, Callable, Tuple, Optional
 from torch import Tensor
@@ -122,6 +123,7 @@
     "get_static_quant_module_class",
     "load_observer_state_dict",
     "move_exported_model_to_eval",
+    "move_exported_model_to_train",
     "no_observer_set",
     "per_channel_weight_observer_range_neg_127_to_127",
     "prepare",
diff --git a/torch/ao/quantization/pt2e/eval_utils.py b/torch/ao/quantization/pt2e/eval_utils.py
index 7699e61..6e47d38 100644
--- a/torch/ao/quantization/pt2e/eval_utils.py
+++ b/torch/ao/quantization/pt2e/eval_utils.py
@@ -2,15 +2,14 @@
 import torch.nn.functional as F
 
 
-def _replace_dropout_for_eval(m: torch.fx.GraphModule):
+def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
     """
-    Replace the aten training dropout pattern with a noop, intended for eval.
+    Switch dropout patterns in the model between train and eval modes.
 
-    For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
-    effectively turns these dropout ops into noops. For exported models, however,
-    this is not done automatically, since the aten dropout patterns previously generated
-    for training remain in the graph. Here we rewrite these dropout patterns with noops
-    to avoid incorrectly applying further dropout during eval.
+    Dropout has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the dropout behavior between the two modes, so here we need to rewrite the aten
+    dropout patterns manually to achieve the same effect.
 
     See https://github.com/pytorch/pytorch/issues/103681.
     """
@@ -30,8 +29,12 @@
             return F.dropout(x, p=0.5, training=False, inplace=inplace)
 
         example_inputs = (torch.randn(1),)
-        match_pattern = get_aten_graph_module(dropout_train, example_inputs)
-        replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+        if train_to_eval:
+            match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+            replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+        else:
+            match_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+            replacement_pattern = get_aten_graph_module(dropout_train, example_inputs)
 
         from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
@@ -45,7 +48,15 @@
         m.recompile()
 
 
-def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
+def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
+    """
+    Switch batchnorm patterns in the model between train and eval modes.
+
+    Batchnorm has different behavior in train vs eval mode. For exported models,
+    however, calling `model.train()` or `model.eval()` does not automatically switch
+    the batchnorm behavior between the two modes, so here we need to rewrite the aten
+    batchnorm patterns manually to achieve the same effect.
+    """
     # TODO(Leslie): This function still fails to support custom momentum and eps value.
     # Enable this support in future updates.
 
@@ -85,8 +96,13 @@
         torch.randn(1),  # bn_running_mean
         torch.randn(1),  # bn_running_var
     )
-    match_pattern = get_aten_graph_module(bn_train, example_inputs)
-    replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
+    if train_to_eval:
+        match_pattern = get_aten_graph_module(bn_train, example_inputs)
+        replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
+    else:
+        match_pattern = get_aten_graph_module(bn_eval, example_inputs)
+        replacement_pattern = get_aten_graph_module(bn_train, example_inputs)
+
     from torch.fx.subgraph_rewriter import replace_pattern_with_filters
 
     replace_pattern_with_filters(
@@ -99,7 +115,6 @@
     m.recompile()
 
 
-# TODO: also support move_exported_model_to_train
 def _move_exported_model_to_eval(model: torch.fx.GraphModule):
     """
     Move an exported GraphModule to eval mode.
@@ -107,6 +122,18 @@
     This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
     QAT users should call this before performing inference on the model.
     """
-    _replace_dropout_for_eval(model)
-    _replace_batchnorm_for_eval(model)
+    _replace_dropout(model, train_to_eval=True)
+    _replace_batchnorm(model, train_to_eval=True)
+    return model
+
+
+def _move_exported_model_to_train(model: torch.fx.GraphModule):
+    """
+    Move an exported GraphModule to train mode.
+
+    This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
+    QAT users should call this before performing training on the model.
+    """
+    _replace_dropout(model, train_to_eval=False)
+    _replace_batchnorm(model, train_to_eval=False)
     return model