ns for fx: move relatedness mapping to mappings file (#57171)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57171
No logic change, just moving the mapping to a file where
the other mappings are.
Test Plan:
```
python test/test_quantization.py TestFXNumericSuiteCoreAPIs
```
Imported from OSS
Reviewed By: jerryzh168
Differential Revision: D28077978
fbshipit-source-id: 4049d6a498156a5dffe3a03d2f4abc79da7bf907
diff --git a/test/quantization/test_numeric_suite_fx.py b/test/quantization/test_numeric_suite_fx.py
index 169f039..46b5a25 100644
--- a/test/quantization/test_numeric_suite_fx.py
+++ b/test/quantization/test_numeric_suite_fx.py
@@ -35,7 +35,6 @@
from torch.quantization.fx.pattern_utils import get_default_quant_patterns
import torch.quantization.fx.quantization_patterns as qp
from torch.quantization.ns.pattern_utils import (
- get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
)
from torch.quantization.ns.graph_matcher import (
@@ -45,6 +44,7 @@
from torch.quantization.ns.mappings import (
get_node_type_to_io_type_map,
get_unmatchable_types_map,
+ get_base_name_to_sets_of_related_ops,
)
from torch.quantization._numeric_suite_fx import (
extract_weights,
diff --git a/torch/quantization/_numeric_suite_fx.py b/torch/quantization/_numeric_suite_fx.py
index e741c21..c4cc2d4 100644
--- a/torch/quantization/_numeric_suite_fx.py
+++ b/torch/quantization/_numeric_suite_fx.py
@@ -5,9 +5,11 @@
import torch.quantization.quantize_fx as quantize_fx
from torch.fx import GraphModule
from torch.fx.graph import Node
+from torch.quantization.ns.mappings import (
+ get_base_name_to_sets_of_related_ops,
+)
from torch.quantization.ns.graph_matcher import (
get_matching_subgraph_pairs,
- get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
)
diff --git a/torch/quantization/ns/graph_matcher.py b/torch/quantization/ns/graph_matcher.py
index 69624db..62f5d62 100644
--- a/torch/quantization/ns/graph_matcher.py
+++ b/torch/quantization/ns/graph_matcher.py
@@ -9,10 +9,10 @@
from .utils import getattr_from_fqn
from .ns_types import NSSubgraph, NSNodeTargetType
from .mappings import (
+ get_base_name_to_sets_of_related_ops,
get_unmatchable_types_map,
)
from .pattern_utils import (
- get_base_name_to_sets_of_related_ops,
get_type_a_related_to_b,
get_reversed_fusions,
end_node_matches_reversed_fusion,
diff --git a/torch/quantization/ns/mappings.py b/torch/quantization/ns/mappings.py
index 16de337..35a4683 100644
--- a/torch/quantization/ns/mappings.py
+++ b/torch/quantization/ns/mappings.py
@@ -16,6 +16,217 @@
from typing import Set, Dict
+
+def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
+ base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {
+ # conv modules
+ 'torch.nn.Conv1d': set([
+ nn.Conv1d,
+ nnq.Conv1d,
+ nniqat.ConvBn1d,
+ nniqat.ConvBnReLU1d,
+ nniq.ConvReLU1d,
+ nni.ConvReLU1d,
+ ]),
+ 'torch.nn.Conv2d': set([
+ nn.Conv2d,
+ nnq.Conv2d,
+ nnqat.Conv2d,
+ nniqat.ConvBn2d,
+ nniqat.ConvBnReLU2d,
+ nniqat.ConvReLU2d,
+ nniq.ConvReLU2d,
+ nni.ConvReLU2d,
+ ]),
+ 'torch.nn.Conv3d': set([
+ nn.Conv3d,
+ nnq.Conv3d,
+ nnqat.Conv3d,
+ nniqat.ConvBn3d,
+ nniqat.ConvBnReLU3d,
+ nniqat.ConvReLU3d,
+ nniq.ConvReLU3d,
+ nni.ConvReLU3d,
+ ]),
+ # conv functionals
+ 'torch.nn.functional.conv1d': set([
+ F.conv1d,
+ toq.conv1d,
+ toq.conv1d_relu,
+ ]),
+ 'torch.nn.functional.conv2d': set([
+ F.conv2d,
+ toq.conv2d,
+ toq.conv2d_relu,
+ ]),
+ 'torch.nn.functional.conv3d': set([
+ F.conv3d,
+ toq.conv3d,
+ toq.conv3d_relu,
+ ]),
+ # linear modules
+ 'torch.nn.Linear': set([
+ nn.Linear,
+ nnq.Linear,
+ nni.LinearReLU,
+ nniq.LinearReLU,
+ nnqat.Linear,
+ nnqd.Linear,
+ nniqat.LinearReLU,
+ nn.modules.linear._LinearWithBias,
+ ]),
+ # linear functionals
+ 'torch.nn.functional.linear': set([
+ F.linear,
+ toq.linear,
+ toq.linear_relu,
+ ]),
+ # LSTM
+ 'torch.nn.LSTM': set([
+ nn.LSTM,
+ nnqd.LSTM,
+ ]),
+ # add
+ 'torch.add': set([
+ torch.add,
+ toq.add,
+ operator.add, # x + y
+ toq.add_relu,
+ ]),
+ # cat
+ 'torch.cat': set([
+ torch.cat,
+ toq.cat,
+ ]),
+ # mul
+ 'torch.mul': set([
+ torch.mul,
+ toq.mul,
+ operator.mul,
+ toq.mul_relu,
+ ]),
+ # relu
+ 'torch.relu': set([
+ F.relu,
+ ]),
+ # maxpool2d
+ 'torch.nn.MaxPool2d': set([
+ nn.MaxPool2d,
+ ]),
+ # sigmoid
+ 'torch.sigmoid': set([
+ torch.sigmoid,
+ 'sigmoid',
+ ]),
+ # BatchNorm
+ 'torch.nn.BatchNorm2d': set([
+ nn.BatchNorm2d,
+ nnq.BatchNorm2d,
+ ]),
+ 'torch.nn.BatchNorm3d': set([
+ nn.BatchNorm3d,
+ nnq.BatchNorm3d,
+ ]),
+ # ConvTranspose
+ 'torch.nn.ConvTranspose1d': set([
+ nn.ConvTranspose1d,
+ nnq.ConvTranspose1d,
+ ]),
+ 'torch.nn.ConvTranspose2d': set([
+ nn.ConvTranspose2d,
+ nnq.ConvTranspose2d,
+ ]),
+ # ELU
+ 'torch.nn.ELU': set([
+ nn.ELU,
+ nnq.ELU,
+ ]),
+ # Embedding
+ 'torch.nn.Embedding': set([
+ nn.Embedding,
+ nnq.Embedding,
+ ]),
+ # EmbeddingBag
+ 'torch.nn.EmbeddingBag': set([
+ nn.EmbeddingBag,
+ nnq.EmbeddingBag,
+ ]),
+ # GroupNorm
+ 'torch.nn.GroupNorm': set([
+ nn.GroupNorm,
+ nnq.GroupNorm,
+ ]),
+ # Hardswish
+ 'torch.nn.Hardswish': set([
+ nn.Hardswish,
+ nnq.Hardswish,
+ ]),
+ # InstanceNorm
+ 'torch.nn.InstanceNorm1d': set([
+ nn.InstanceNorm1d,
+ nnq.InstanceNorm1d,
+ ]),
+ 'torch.nn.InstanceNorm2d': set([
+ nn.InstanceNorm2d,
+ nnq.InstanceNorm2d,
+ ]),
+ 'torch.nn.InstanceNorm3d': set([
+ nn.InstanceNorm3d,
+ nnq.InstanceNorm3d,
+ ]),
+ # LayerNorm
+ 'torch.nn.LayerNorm': set([
+ nn.LayerNorm,
+ nnq.LayerNorm,
+ ]),
+ # LeakyReLU
+ 'torch.nn.LeakyReLU': set([
+ nn.LeakyReLU,
+ nnq.LeakyReLU,
+ ]),
+ # ReLU6
+ 'torch.nn.ReLU6': set([
+ nn.ReLU6,
+ nnq.ReLU6,
+ ]),
+ # BNReLU2d
+ 'torch.nn.intrinsic.BNReLU2d': set([
+ nni.BNReLU2d,
+ nniq.BNReLU2d,
+ ]),
+ 'torch.nn.intrinsic.BNReLU3d': set([
+ nni.BNReLU3d,
+ nniq.BNReLU3d,
+ ]),
+ # F.elu
+ 'torch.nn.functional.elu': set([
+ F.elu,
+ toq.elu,
+ ]),
+ # F.hardswish
+ 'torch.nn.functional.hardswish': set([
+ F.hardswish,
+ toq.hardswish,
+ ]),
+ # F.instance_norm
+ 'torch.nn.functional.instance_norm': set([
+ F.instance_norm,
+ toq.instance_norm,
+ ]),
+ # F.layer_norm
+ 'torch.nn.functional.layer_norm': set([
+ F.layer_norm,
+ toq.layer_norm,
+ ]),
+ # F.leaky_relu
+ 'torch.nn.functional.leaky_relu': set([
+ F.leaky_relu,
+ toq.leaky_relu,
+ ]),
+ }
+ return base_name_to_sets_of_related_ops
+
+
# TODO(future PR): clean this up
def get_node_type_to_io_type_map() -> Dict[str, Set[NSNodeTargetType]]:
FUNS_IO_TYPE_FP32: Set[NSNodeTargetType] = set([
diff --git a/torch/quantization/ns/pattern_utils.py b/torch/quantization/ns/pattern_utils.py
index 6498d99..1ddde72 100644
--- a/torch/quantization/ns/pattern_utils.py
+++ b/torch/quantization/ns/pattern_utils.py
@@ -1,17 +1,7 @@
-import operator
-
import torch
-import torch.nn as nn
import torch.nn.functional as F
toq = torch.ops.quantized
-import torch.nn.quantized as nnq
-import torch.nn.quantized.dynamic as nnqd
-import torch.nn.qat as nnqat
-import torch.nn.intrinsic.quantized as nniq
-import torch.nn.intrinsic.qat as nniqat
-import torch.nn.intrinsic as nni
-
from torch.fx import GraphModule
from torch.fx.graph import Node
@@ -21,215 +11,6 @@
from typing import Dict, Tuple, Set, Callable, Any, Union
-def get_base_name_to_sets_of_related_ops() -> Dict[str, Set[NSNodeTargetType]]:
- base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]] = {
- # conv modules
- 'torch.nn.Conv1d': set([
- nn.Conv1d,
- nnq.Conv1d,
- nniqat.ConvBn1d,
- nniqat.ConvBnReLU1d,
- nniq.ConvReLU1d,
- nni.ConvReLU1d,
- ]),
- 'torch.nn.Conv2d': set([
- nn.Conv2d,
- nnq.Conv2d,
- nnqat.Conv2d,
- nniqat.ConvBn2d,
- nniqat.ConvBnReLU2d,
- nniqat.ConvReLU2d,
- nniq.ConvReLU2d,
- nni.ConvReLU2d,
- ]),
- 'torch.nn.Conv3d': set([
- nn.Conv3d,
- nnq.Conv3d,
- nnqat.Conv3d,
- nniqat.ConvBn3d,
- nniqat.ConvBnReLU3d,
- nniqat.ConvReLU3d,
- nniq.ConvReLU3d,
- nni.ConvReLU3d,
- ]),
- # conv functionals
- 'torch.nn.functional.conv1d': set([
- F.conv1d,
- toq.conv1d,
- toq.conv1d_relu,
- ]),
- 'torch.nn.functional.conv2d': set([
- F.conv2d,
- toq.conv2d,
- toq.conv2d_relu,
- ]),
- 'torch.nn.functional.conv3d': set([
- F.conv3d,
- toq.conv3d,
- toq.conv3d_relu,
- ]),
- # linear modules
- 'torch.nn.Linear': set([
- nn.Linear,
- nnq.Linear,
- nni.LinearReLU,
- nniq.LinearReLU,
- nnqat.Linear,
- nnqd.Linear,
- nniqat.LinearReLU,
- nn.modules.linear._LinearWithBias,
- ]),
- # linear functionals
- 'torch.nn.functional.linear': set([
- F.linear,
- toq.linear,
- toq.linear_relu,
- ]),
- # LSTM
- 'torch.nn.LSTM': set([
- nn.LSTM,
- nnqd.LSTM,
- ]),
- # add
- 'torch.add': set([
- torch.add,
- toq.add,
- operator.add, # x + y
- toq.add_relu,
- ]),
- # cat
- 'torch.cat': set([
- torch.cat,
- toq.cat,
- ]),
- # mul
- 'torch.mul': set([
- torch.mul,
- toq.mul,
- operator.mul,
- toq.mul_relu,
- ]),
- # relu
- 'torch.relu': set([
- F.relu,
- ]),
- # maxpool2d
- 'torch.nn.MaxPool2d': set([
- nn.MaxPool2d,
- ]),
- # sigmoid
- 'torch.sigmoid': set([
- torch.sigmoid,
- 'sigmoid',
- ]),
- # BatchNorm
- 'torch.nn.BatchNorm2d': set([
- nn.BatchNorm2d,
- nnq.BatchNorm2d,
- ]),
- 'torch.nn.BatchNorm3d': set([
- nn.BatchNorm3d,
- nnq.BatchNorm3d,
- ]),
- # ConvTranspose
- 'torch.nn.ConvTranspose1d': set([
- nn.ConvTranspose1d,
- nnq.ConvTranspose1d,
- ]),
- 'torch.nn.ConvTranspose2d': set([
- nn.ConvTranspose2d,
- nnq.ConvTranspose2d,
- ]),
- # ELU
- 'torch.nn.ELU': set([
- nn.ELU,
- nnq.ELU,
- ]),
- # Embedding
- 'torch.nn.Embedding': set([
- nn.Embedding,
- nnq.Embedding,
- ]),
- # EmbeddingBag
- 'torch.nn.EmbeddingBag': set([
- nn.EmbeddingBag,
- nnq.EmbeddingBag,
- ]),
- # GroupNorm
- 'torch.nn.GroupNorm': set([
- nn.GroupNorm,
- nnq.GroupNorm,
- ]),
- # Hardswish
- 'torch.nn.Hardswish': set([
- nn.Hardswish,
- nnq.Hardswish,
- ]),
- # InstanceNorm
- 'torch.nn.InstanceNorm1d': set([
- nn.InstanceNorm1d,
- nnq.InstanceNorm1d,
- ]),
- 'torch.nn.InstanceNorm2d': set([
- nn.InstanceNorm2d,
- nnq.InstanceNorm2d,
- ]),
- 'torch.nn.InstanceNorm3d': set([
- nn.InstanceNorm3d,
- nnq.InstanceNorm3d,
- ]),
- # LayerNorm
- 'torch.nn.LayerNorm': set([
- nn.LayerNorm,
- nnq.LayerNorm,
- ]),
- # LeakyReLU
- 'torch.nn.LeakyReLU': set([
- nn.LeakyReLU,
- nnq.LeakyReLU,
- ]),
- # ReLU6
- 'torch.nn.ReLU6': set([
- nn.ReLU6,
- nnq.ReLU6,
- ]),
- # BNReLU2d
- 'torch.nn.intrinsic.BNReLU2d': set([
- nni.BNReLU2d,
- nniq.BNReLU2d,
- ]),
- 'torch.nn.intrinsic.BNReLU3d': set([
- nni.BNReLU3d,
- nniq.BNReLU3d,
- ]),
- # F.elu
- 'torch.nn.functional.elu': set([
- F.elu,
- toq.elu,
- ]),
- # F.hardswish
- 'torch.nn.functional.hardswish': set([
- F.hardswish,
- toq.hardswish,
- ]),
- # F.instance_norm
- 'torch.nn.functional.instance_norm': set([
- F.instance_norm,
- toq.instance_norm,
- ]),
- # F.layer_norm
- 'torch.nn.functional.layer_norm': set([
- F.layer_norm,
- toq.layer_norm,
- ]),
- # F.leaky_relu
- 'torch.nn.functional.leaky_relu': set([
- F.leaky_relu,
- toq.leaky_relu,
- ]),
- }
- return base_name_to_sets_of_related_ops
-
def get_type_a_related_to_b(
base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],