[quant][pt2] Add `move_exported_model_to_train` (#113492)
Summary: This adds the equivalent of `model.train()` for exported
models, analogous to the existing `move_exported_model_to_eval`.
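Example usage (a minimal sketch; the toy module below is illustrative,
and `capture_pre_autograd_graph` is assumed importable from
`torch._export`):

    import torch
    from torch._export import capture_pre_autograd_graph

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.bn = torch.nn.BatchNorm2d(3)
            self.dropout = torch.nn.Dropout(0.5)

        def forward(self, x):
            return self.dropout(self.bn(x))

    example_inputs = (torch.randn(1, 3, 3, 3),)
    m = capture_pre_autograd_graph(M().train(), example_inputs)

    # The exported graph does not respond to m.train()/m.eval(), so
    # switch dropout/batchnorm behavior explicitly instead:
    torch.ao.quantization.move_exported_model_to_eval(m)   # before inference
    torch.ao.quantization.move_exported_model_to_train(m)  # before (re)training
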
Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_dropout
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_dropout_inplace
python test/test_quantization.py TestQuantizePT2E.test_move_exported_model_bn
Reviewers: jerryzh168, kimishpatel
Subscribers: jerryzh168, kimishpatel, supriyar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113492
Approved by: https://github.com/jerryzh168, https://github.com/tugsbayasgalan
diff --git a/test/quantization/pt2e/test_quantize_pt2e.py b/test/quantization/pt2e/test_quantize_pt2e.py
index eb2bc61..177c5ae 100644
--- a/test/quantization/pt2e/test_quantize_pt2e.py
+++ b/test/quantization/pt2e/test_quantize_pt2e.py
@@ -1498,7 +1498,22 @@
qconfig_mapping,
)
- def _test_move_exported_model_to_eval_dropout(self, inplace=False):
+ def _get_node(self, m: torch.fx.GraphModule, target: torch._ops.OpOverload):
+ """
+ Return the first node matching the specified target, raising a ValueError
+ if no such node is found.
+ """
+ for n in m.graph.nodes:
+ if n.target == target:
+ return n
+ raise ValueError(f"Did not find node with target {target}")
+
+ def _test_move_exported_model_dropout(self, inplace: bool):
+ """
+ Test switching dropout behavior between train and eval modes using
+ the `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
+ """
+
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1510,71 +1525,75 @@
example_inputs = (torch.randn(1),)
m = M().train()
m = capture_pre_autograd_graph(m, example_inputs)
+ if inplace:
+ target = torch.ops.aten.dropout_.default
+ else:
+ target = torch.ops.aten.dropout.default
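+ # The aten dropout schema is (input, p, train), so args[2] is the train flag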
# Assert that dropout op exists and is in train mode
- dropout_node = None
- for n in m.graph.nodes:
- if n.target == torch.ops.aten.native_dropout.default or n.target == torch.ops.aten.dropout_.default:
- dropout_node = n
- break
+ dropout_node = self._get_node(m, target)
self.assertTrue(dropout_node is not None)
self.assertTrue(dropout_node.args[2])
- # Do the subgraph rewriting
+ # Move to eval
torch.ao.quantization.move_exported_model_to_eval(m)
- # Assert that dropout op is now replaced with a clone op
- targets = [n.target for n in m.graph.nodes]
- if inplace:
- dropout_eval_node = None
- for node in m.graph.nodes:
- if node.target == torch.ops.aten.dropout_.default:
- dropout_eval_node = node
- self.assertTrue(dropout_eval_node is not None)
- self.assertFalse(dropout_eval_node.args[2])
- else:
- self.assertTrue(torch.ops.aten.clone.default in targets)
- self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
+ # Assert that dropout op is now in eval mode
+ dropout_node = self._get_node(m, target)
+ self.assertTrue(dropout_node is not None)
+ self.assertFalse(dropout_node.args[2])
- def test_move_exported_model_to_eval(self):
- self._test_move_exported_model_to_eval_dropout(inplace=False)
- self._test_move_exported_model_to_eval_dropout(inplace=True)
+ # Move back to train
+ torch.ao.quantization.move_exported_model_to_train(m)
- def test_bn_move_exported_model_to_eval(self):
+ # Assert that dropout op is now in train mode again
+ dropout_node = self._get_node(m, target)
+ self.assertTrue(dropout_node is not None)
+ self.assertTrue(dropout_node.args[2])
+
+ def test_move_exported_model_dropout(self):
+ self._test_move_exported_model_dropout(inplace=False)
+
+ def test_move_exported_model_dropout_inplace(self):
+ self._test_move_exported_model_dropout(inplace=True)
+
+ def test_move_exported_model_bn(self):
+ """
+ Test switching batch_norm behavior between train and eval modes using
+ the `move_exported_model_to_eval` and `move_exported_model_to_train` APIs.
+ """
+
class M(torch.nn.Module):
- def __init__(
- self,
- ):
+ def __init__(self):
super().__init__()
self.bn = torch.nn.BatchNorm2d(3)
- self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
- return self.conv(self.bn(x))
+ return self.bn(x)
+ example_inputs = (torch.randn(1, 3, 3, 3),)
m = M().train()
- example_inputs = (
- torch.randn((1, 3, 8, 8), dtype=torch.float32, requires_grad=True).add(1),
- )
-
m = capture_pre_autograd_graph(m, example_inputs)
- # Assert that bn op exists and is in train mode
- batch_norm_node = None
- for n in m.graph.nodes:
- if n.target == torch.ops.aten._native_batch_norm_legit.default:
- batch_norm_node = n
- break
- self.assertTrue(batch_norm_node is not None)
- self.assertTrue(batch_norm_node.args[5])
+ # Assert that batch norm op exists and is in train mode
+ bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
+ self.assertTrue(bn_node is not None)
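+ # args[5] of _native_batch_norm_legit is the `training` flag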
+ self.assertTrue(bn_node.args[5])
- # Do the subgraph rewriting
+ # Move to eval
torch.ao.quantization.move_exported_model_to_eval(m)
- # Assert that bn op is now in eval mode
- targets = [n.target for n in m.graph.nodes]
- self.assertTrue(torch.ops.aten._native_batch_norm_legit.default not in targets)
- self.assertTrue(torch.ops.aten._native_batch_norm_legit_no_training.default in targets)
+ # Assert that batch norm op is now in eval mode
+ bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit_no_training.default)
+ self.assertTrue(bn_node is not None)
+
+ # Move back to train
+ torch.ao.quantization.move_exported_model_to_train(m)
+
+ # Assert that batch norm op is now in train mode again
+ bn_node = self._get_node(m, torch.ops.aten._native_batch_norm_legit.default)
+ self.assertTrue(bn_node is not None)
+ self.assertTrue(bn_node.args[5])
def test_disallow_eval_train(self):
m = TestHelperModules.ConvWithBNRelu(relu=True)
diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py
index ddfddb5..1f3e55d 100644
--- a/torch/_export/__init__.py
+++ b/torch/_export/__init__.py
@@ -156,7 +156,14 @@
else:
constraints = _process_dynamic_shapes(f, args, kwargs, dynamic_shapes)
- decomp_table = {op: op.decompose for op in FunctionalTensor.maybe_aliasing_or_mutating_ops}
+ # Do not decompose dropout for exported models, because in eval mode the dropout
+ # op disappears from the graph, which makes it difficult to switch back to train mode.
+ # See https://github.com/pytorch/pytorch/pull/115258#issuecomment-1900755832.
+ decomp_table = {
+ op: op.decompose
+ for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
+ if op != torch.ops.aten.dropout.default
+ }
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
m = torch._dynamo.export(
f,
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index 04323f6..23e469b 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -13,6 +13,7 @@
from .quantize_jit import * # noqa: F403
from .stubs import * # noqa: F403
from .pt2e.eval_utils import _move_exported_model_to_eval as move_exported_model_to_eval
+from .pt2e.eval_utils import _move_exported_model_to_train as move_exported_model_to_train
from .pt2e.generate_numeric_debug_handle import generate_numeric_debug_handle # noqa: F401
from typing import Union, List, Callable, Tuple, Optional
from torch import Tensor
@@ -122,6 +123,7 @@
"get_static_quant_module_class",
"load_observer_state_dict",
"move_exported_model_to_eval",
+ "move_exported_model_to_train",
"no_observer_set",
"per_channel_weight_observer_range_neg_127_to_127",
"prepare",
diff --git a/torch/ao/quantization/pt2e/eval_utils.py b/torch/ao/quantization/pt2e/eval_utils.py
index 7699e61..6e47d38 100644
--- a/torch/ao/quantization/pt2e/eval_utils.py
+++ b/torch/ao/quantization/pt2e/eval_utils.py
@@ -2,15 +2,14 @@
import torch.nn.functional as F
-def _replace_dropout_for_eval(m: torch.fx.GraphModule):
+def _replace_dropout(m: torch.fx.GraphModule, train_to_eval: bool):
"""
- Replace the aten training dropout pattern with a noop, intended for eval.
+ Switch dropout patterns in the model between train and eval modes.
- For models with dropout torch ops (nn.Dropout, F.dropout), calling model.eval()
- effectively turns these dropout ops into noops. For exported models, however,
- this is not done automatically, since the aten dropout patterns previously generated
- for training remain in the graph. Here we rewrite these dropout patterns with noops
- to avoid incorrectly applying further dropout during eval.
+ Dropout has different behavior in train vs eval mode. For exported models,
+ however, calling `model.train()` or `model.eval()` does not automatically switch
+ the dropout behavior between the two modes, so here we need to rewrite the aten
+ dropout patterns manually to achieve the same effect.
See https://github.com/pytorch/pytorch/issues/103681.
"""
@@ -30,8 +29,12 @@
return F.dropout(x, p=0.5, training=False, inplace=inplace)
example_inputs = (torch.randn(1),)
- match_pattern = get_aten_graph_module(dropout_train, example_inputs)
- replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
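+ # The same two traced patterns serve both directions; which one is the
+ # match and which the replacement depends on the target mode.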
+ if train_to_eval:
+ match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+ replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+ else:
+ match_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+ replacement_pattern = get_aten_graph_module(dropout_train, example_inputs)
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
@@ -45,7 +48,15 @@
m.recompile()
-def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
+def _replace_batchnorm(m: torch.fx.GraphModule, train_to_eval: bool):
+ """
+ Switch batchnorm patterns in the model between train and eval modes.
+
+ Batchnorm has different behavior in train vs eval mode. For exported models,
+ however, calling `model.train()` or `model.eval()` does not automatically switch
+ the batchnorm behavior between the two modes, so here we need to rewrite the aten
+ batchnorm patterns manually to achieve the same effect.
+ """
# TODO(Leslie): This function does not yet support custom momentum and eps values.
# Enable this support in future updates.
@@ -85,8 +96,13 @@
torch.randn(1), # bn_running_mean
torch.randn(1), # bn_running_var
)
- match_pattern = get_aten_graph_module(bn_train, example_inputs)
- replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
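+ # As with dropout, swap the match and replacement patterns depending on
+ # the direction of the rewrite.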
+ if train_to_eval:
+ match_pattern = get_aten_graph_module(bn_train, example_inputs)
+ replacement_pattern = get_aten_graph_module(bn_eval, example_inputs)
+ else:
+ match_pattern = get_aten_graph_module(bn_eval, example_inputs)
+ replacement_pattern = get_aten_graph_module(bn_train, example_inputs)
+
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
replace_pattern_with_filters(
@@ -99,7 +115,6 @@
m.recompile()
-# TODO: also support move_exported_model_to_train
def _move_exported_model_to_eval(model: torch.fx.GraphModule):
"""
Move an exported GraphModule to eval mode.
@@ -107,6 +122,18 @@
This is equivalent to model.eval() but only for certain special ops like dropout, batchnorm.
QAT users should call this before performing inference on the model.
"""
- _replace_dropout_for_eval(model)
- _replace_batchnorm_for_eval(model)
+ _replace_dropout(model, train_to_eval=True)
+ _replace_batchnorm(model, train_to_eval=True)
+ return model
+
+
+def _move_exported_model_to_train(model: torch.fx.GraphModule):
+ """
+ Move an exported GraphModule to train mode.
+
+ This is equivalent to model.train() but only for certain special ops like dropout, batchnorm.
+ QAT users should call this before performing training on the model.
+ """
+ _replace_dropout(model, train_to_eval=False)
+ _replace_batchnorm(model, train_to_eval=False)
return model