[fx][const fold] Add test/example for skipping quant/dequant pattern (#68378)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/68378

Add test/example for skipping quant/dequant pattern
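
For reference, a minimal sketch of the intended usage (the test added in this diff is the authoritative example; `model` and `sample_input` below are placeholders, not names from this PR):

```python
import torch
import torch.fx
from torch.fx.experimental import const_fold
from torch.fx.experimental.fx_acc import acc_ops, acc_tracer


def skip_folding_quant_dequant(node: torch.fx.Node) -> bool:
    # Only quantize_per_tensor nodes are candidates for skipping.
    if node.target != acc_ops.quantize_per_tensor:
        return False
    # Skip folding when the quantized value is immediately dequantized.
    return any(user.target == acc_ops.dequantize for user in node.users)


traced = acc_tracer.trace(model, sample_input)
folded = const_fold.split_const_subgraphs(
    traced, skip_folding_node_fn=skip_folding_quant_dequant
)
# With every quant/dequant pair skipped, no constant subgraph is extracted.
assert folded.const_subgraph_module is None
```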

Reviewed By: jfix71

Differential Revision: D32410544

fbshipit-source-id: e63419a01a097e4c570c3861d79d573cabc0b294
diff --git a/test/fx/test_fx_const_fold.py b/test/fx/test_fx_const_fold.py
index a0078bf..2f70e9b 100644
--- a/test/fx/test_fx_const_fold.py
+++ b/test/fx/test_fx_const_fold.py
@@ -5,6 +5,7 @@
 import torch
 import torch.fx
 from torch.fx.experimental import const_fold
+from torch.fx.experimental.fx_acc import acc_tracer, acc_ops
 from torch.testing._internal.common_utils import TestCase
 
 
@@ -585,3 +586,47 @@
         base_result = mod(in_x)
         self.assertTrue(torch.equal(fold_result[0], base_result[0]))
         self.assertTrue(torch.equal(fold_result[1], base_result[1]))
+
+    def test_check_skip_folding_quant_dequant_pattern(self):
+        r"""
+        Set up a skip_folding_quant_dequant function to skip the quant/dequant pattern.
+        This example shows how to use skip_folding_node_fn.
+        """
+
+        class ConstFoldTestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.weight = torch.nn.Parameter(torch.randn(4, 4))
+                self.bias = torch.nn.Parameter(torch.randn(4))
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                quant_weight = torch.quantize_per_tensor(self.weight, 0.5, 3, torch.quint8)
+                dequant_weight = torch.dequantize(quant_weight)
+                output = torch.nn.functional.linear(x, dequant_weight, self.bias)
+                return self.relu(output)
+
+        mod = ConstFoldTestModule()
+        in_x = torch.randn(2, 4)
+        gm = acc_tracer.trace(mod, in_x)
+
+        def skip_folding_quant_dequant(node: torch.fx.Node):
+            if node.target != acc_ops.quantize_per_tensor:
+                return False
+            # If quantize_per_tensor -> dequantize, then skip folding.
+            for user in node.users:
+                if user.target == acc_ops.dequantize:
+                    return True
+            return False
+
+        gm_folded: const_fold.FoldedGraphModule = const_fold.split_const_subgraphs(
+            gm, skip_folding_node_fn=skip_folding_quant_dequant
+        )
+
+        # Check that the const subgraph module is None, since there was no folding to do.
+        self.assertTrue(gm_folded.const_subgraph_module is None)
+
+        # Now run both folded and non-folded to check results equal.
+        fold_result = gm_folded(in_x)
+        base_result = mod(in_x)
+        self.assertTrue(torch.equal(fold_result, base_result))