ns for fx: skip shadowing for torch.cat, and also for nodes with only kwargs (#76561)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76561
A user model contained syntax like `torch.cat(tensors=[x])`. This PR fixes two issues to unbreak this in the NS shadow model (a minimal sketch of the pattern follows the list):
1. skip nodes that only have kwargs (instead of throwing an exception)
2. explicitly skip shadowing of `torch.cat` and `torch.stack` (since shadowing them is not supported anyway)
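The sketch below illustrates the kwargs-only pattern; the module name `KwargsOnlyCat` is illustrative and not part of the PR. Because `torch.cat` is called with only keyword arguments, the traced FX node has an empty `args` tuple, and with this change the shadow pass copies such nodes through unchanged instead of raising:
```
import torch
import torch.nn as nn

class KwargsOnlyCat(nn.Module):
    # Illustrative module, not from the PR: torch.cat is invoked with only
    # keyword arguments, so node.args is empty when traced by FX.
    def forward(self, x):
        return torch.cat(tensors=[x, x])

m = KwargsOnlyCat().eval()
print(m(torch.randn(2, 2)).shape)  # torch.Size([4, 2])
```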
Test Plan:
```
python test/test_quantization.py -k test_op_with_only_kwargs_skips_shadowing
python test/test_quantization.py -k test_mul_add_cat_stack_skips_shadowing
```
Reviewed By: hx89
Differential Revision: D36017356
Pulled By: vkuzo
fbshipit-source-id: 0da4840a62c2dac183f8294c2cec4fce262474b3
(cherry picked from commit 88409c1576e7f690708957b2baa285fc7961e9d6)
diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index b60f1cb..d939e7d 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -1172,33 +1172,6 @@
results_len=0)
@skipIfNoFBGEMM
- def test_add_shadow_loggers_multiple_dtype_casts(self):
- """
- Verifies that for nodes where the first input arg is a list,
- such as `cat`, we insert an individual dtype cast for each
- arg of the list.
- """
- class M(nn.Module):
- def __init__(self):
- super().__init__()
-
- def forward(self, x):
- x = torch.cat([x, x, x], dim=0)
- return x
-
- m = M().eval()
- expected_occurrence = {
- # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
- ns.call_module(torch.nn.Identity): 3,
- # 1 dequantize method call for module output
- ns.call_method("dequantize"): 1,
- }
- self._test_match_shadow_activations(
- m, (torch.randn(4, 4),),
- prepared_expected_node_occurrence=expected_occurrence,
- results_len=1, compare_fp32_vs_fp32_prepared=False)
-
- @skipIfNoFBGEMM
def test_shadow_activations_fqn(self):
m = nn.Sequential(
nn.Sequential(nn.Conv2d(1, 1, 1)),
@@ -1237,7 +1210,7 @@
m = M().eval()
self._test_match_shadow_activations(
m, (torch.randn(1, 1, 4, 4),),
- results_len=2,
+ results_len=1,
should_log_inputs=True)
@skipIfNoFBGEMM
@@ -1954,13 +1927,27 @@
mq = convert_fx(mp, is_reference=True)
mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)
- def test_mul_add_skips_shadowing(self):
+ def test_mul_add_cat_stack_skips_shadowing(self):
class M(nn.Module):
def forward(self, x):
x = x * x
x = torch.mul(x, x)
x = x + x
x = torch.add(x, x)
+ x = torch.cat([x])
+ x = torch.stack([x])
+ return x
+
+ m = M().eval()
+ self._test_match_shadow_activations(
+ m, (torch.randn(1, 1, 4, 4),),
+ results_len=0)
+
+ def test_op_with_only_kwargs_skips_shadowing(self):
+ class M(nn.Module):
+ def forward(self, x):
+ x = torch.cat(tensors=[x])
+ x = torch.stack(tensors=[x])
return x
m = M().eval()
@@ -2101,7 +2088,7 @@
x = torch.randn(2, 4)
self._test_match_shadow_activations(
sparse_nn, (idx, offsets, x),
- results_len=4,
+ results_len=3,
should_log_inputs=should_log_inputs)
@skip_if_no_torchvision
diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py
index 98a0228..aa99ef7 100644
--- a/torch/ao/ns/fx/graph_passes.py
+++ b/torch/ao/ns/fx/graph_passes.py
@@ -602,6 +602,26 @@
subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
end_node_b_to_matched_subgraph_a_and_name[node_b]
+ if len(node_b.args) == 0:
+ print(
+ f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+ f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+ ', kwargs-only node not handled yet')
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+ continue
+
+ all_op_types_support_shadowing = (
+ op_type_supports_shadowing(subgraph_a.start_node) and
+ op_type_supports_shadowing(node_b)
+ )
+ if not all_op_types_support_shadowing:
+ print(
+ f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+ f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+ ', unsupported')
+ env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+ continue
+
# For both start_node and end_node verify that we know how to do
# the dtype cast. If we do not, skip.
node_input_type_a, node_output_type_a = \
@@ -626,18 +646,6 @@
env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
continue
- all_op_types_support_shadowing = (
- op_type_supports_shadowing(subgraph_a.start_node) and
- op_type_supports_shadowing(node_b)
- )
- if not all_op_types_support_shadowing:
- print(
- f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
- f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
- ', unsupported')
- env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
- continue
-
# If we are shadowing from fp32 to int8, we need to insert
# quantize_per_tensor call with qparams from the previous node.
# Only do this if we are able to infer these qparams from the graph.
diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py
index f13d303..8f1f277 100644
--- a/torch/ao/ns/fx/utils.py
+++ b/torch/ao/ns/fx/utils.py
@@ -492,7 +492,7 @@
def op_type_supports_shadowing(node: Node) -> bool:
if node.op == 'call_function':
- if node.target in (torch.add, torch.mul, operator.add, operator.mul):
- # shadowing for ops with two inputs is not implemented yet
+ if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
+ # shadowing for ops with multiple tensor inputs is not implemented yet
return False
return True