[fx][const fold] Add test/example for skipping quant/dequant pattern (#68378)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68378
Add test/example for skipping quant/dequant pattern
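To make the mechanism concrete, here is a minimal sketch of the same skip-callback idea outside the quant/dequant case. The module M, the skip_adds predicate, and the use of torch.fx.symbolic_trace below are illustrative assumptions and not part of this change; only const_fold.split_const_subgraphs and its skip_folding_node_fn argument come from the code under test.

import operator

import torch
import torch.fx
from torch.fx.experimental import const_fold


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(4))

    def forward(self, x):
        # `self.w + 1` only depends on a parameter, so const folding would
        # normally pull it out into a separate constant subgraph.
        return x + (self.w + 1)


def skip_adds(node: torch.fx.Node) -> bool:
    # Returning True tells split_const_subgraphs to leave this node unfolded.
    return node.op == "call_function" and node.target == operator.add


mod = M()
gm_folded = const_fold.split_const_subgraphs(
    torch.fx.symbolic_trace(mod), skip_folding_node_fn=skip_adds
)

# With the add skipped there is nothing left to fold, so no constant subgraph
# is extracted, and the folded module still matches the original.
assert gm_folded.const_subgraph_module is None
x = torch.randn(2, 4)
assert torch.equal(gm_folded(x), mod(x))

The callback is consulted per candidate node, so a structural pattern such as quantize_per_tensor feeding dequantize can be detected by inspecting node.users before deciding to skip, which is what the new test below does.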
Reviewed By: jfix71
Differential Revision: D32410544
fbshipit-source-id: e63419a01a097e4c570c3861d79d573cabc0b294
diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py
index a0078bf..2f70e9b 100644
--- a/test/fx/test_fx_const_fold.py
+++ b/test/fx/test_fx_const_fold.py
@@ -5,6 +5,7 @@
import torch
import torch.fx
from torch.fx.experimental import const_fold
+from torch.fx.experimental.fx_acc import acc_tracer, acc_ops
from torch.testing._internal.common_utils import TestCase
@@ -585,3 +586,47 @@
base_result = mod(in_x)
self.assertTrue(torch.equal(fold_result[0], base_result[0]))
self.assertTrue(torch.equal(fold_result[1], base_result[1]))
+
+ def test_check_skip_folding_quant_dequant_pattern(self):
+ r"""
+        Set up a skip_folding_quant_dequant function to skip the quant/dequant pattern.
+ This example shows how to use skip_folding_node_fn.
+ """
+
+ class ConstFoldTestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.weight = torch.nn.Parameter(torch.randn(4, 4))
+ self.bias = torch.nn.Parameter(torch.randn(4))
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ quant_weight = torch.quantize_per_tensor(self.weight, 0.5, 3, torch.quint8)
+ dequant_weight = torch.dequantize(quant_weight)
+ output = torch.nn.functional.linear(x, dequant_weight, self.bias)
+ return self.relu(output)
+
+ mod = ConstFoldTestModule()
+ in_x = torch.randn(2, 4)
+ gm = acc_tracer.trace(mod, in_x)
+
+        def skip_folding_quant_dequant(node: torch.fx.Node) -> bool:
+ if node.target != acc_ops.quantize_per_tensor:
+ return False
+            # If quantize_per_tensor -> dequantize, then skip folding.
+ for user in node.users:
+ if user.target == acc_ops.dequantize:
+ return True
+ return False
+
+ gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+ gm, skip_folding_node_fn=skip_folding_quant_dequant
+ )
+
+        # Check that the const subgraph module is None, since there was nothing to fold.
+ self.assertTrue(gm_folded.const_subgraph_module is None)
+
+ # Now run both folded and non-folded to check results equal.
+ fold_result = gm_folded(in_x)
+ base_result = mod(in_x)
+ self.assertTrue(torch.equal(fold_result, base_result))