ns for fx: skip shadowing for torch.cat, and also for nodes with only kwargs (#76561)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76561

A user model called `torch.cat(tensors=[x])`, i.e. with keyword arguments only. This PR fixes two issues
to unbreak this in the NS shadow model:
1. skip nodes which only have kwargs (instead of throwing an exception)
2. explicitly skip shadowing of `torch.cat` (since it's not supported anyway)
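
For context, a minimal sketch of the failing pattern (mirroring the new tests in this diff): when `torch.cat` is called with keyword arguments only, the traced FX node has empty `args`, which the shadow pass previously did not handle:

```
import torch
import torch.nn as nn

class M(nn.Module):
    def forward(self, x):
        # kwargs-only call: the corresponding FX node has len(node.args) == 0,
        # which previously raised inside add_shadow_loggers
        x = torch.cat(tensors=[x])
        return x
```

With this change, such kwargs-only nodes (and `torch.cat` / `torch.stack` in general) are copied through without shadow loggers instead of raising, so the tests expect `results_len=0`.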

Test Plan:
```
python test/test_quantization.py -k test_op_with_only_kwargs_skips_shadowing
python test/test_quantization.py -k test_op_mul_add_cat_skips_shadowing
```

Reviewed By: hx89

Differential Revision: D36017356

Pulled By: vkuzo

fbshipit-source-id: 0da4840a62c2dac183f8294c2cec4fce262474b3
(cherry picked from commit 88409c1576e7f690708957b2baa285fc7961e9d6)
diff --git a/test/quantization/fx/test_numeric_suite_fx.py b/test/quantization/fx/test_numeric_suite_fx.py
index b60f1cb..d939e7d 100644
--- a/test/quantization/fx/test_numeric_suite_fx.py
+++ b/test/quantization/fx/test_numeric_suite_fx.py
@@ -1172,33 +1172,6 @@
             results_len=0)
 
     @skipIfNoFBGEMM
-    def test_add_shadow_loggers_multiple_dtype_casts(self):
-        """
-        Verifies that for nodes where the first input arg is a list,
-        such as `cat`, we insert an individual dtype cast for each
-        arg of the list.
-        """
-        class M(nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.cat([x, x, x], dim=0)
-                return x
-
-        m = M().eval()
-        expected_occurrence = {
-            # 3 dequantize function calls from the 3 dtype casts for [x, x, x]
-            ns.call_module(torch.nn.Identity): 3,
-            # 1 dequantize method call for module output
-            ns.call_method("dequantize"): 1,
-        }
-        self._test_match_shadow_activations(
-            m, (torch.randn(4, 4),),
-            prepared_expected_node_occurrence=expected_occurrence,
-            results_len=1, compare_fp32_vs_fp32_prepared=False)
-
-    @skipIfNoFBGEMM
     def test_shadow_activations_fqn(self):
         m = nn.Sequential(
             nn.Sequential(nn.Conv2d(1, 1, 1)),
@@ -1237,7 +1210,7 @@
         m = M().eval()
         self._test_match_shadow_activations(
             m, (torch.randn(1, 1, 4, 4),),
-            results_len=2,
+            results_len=1,
             should_log_inputs=True)
 
     @skipIfNoFBGEMM
@@ -1954,13 +1927,27 @@
         mq = convert_fx(mp, is_reference=True)
         mq_shadows_m = add_shadow_loggers('a', mq, 'b', m, OutputLogger)
 
-    def test_mul_add_skips_shadowing(self):
+    def test_mul_add_cat_stack_skips_shadowing(self):
         class M(nn.Module):
             def forward(self, x):
                 x = x * x
                 x = torch.mul(x, x)
                 x = x + x
                 x = torch.add(x, x)
+                x = torch.cat([x])
+                x = torch.stack([x])
+                return x
+
+        m = M().eval()
+        self._test_match_shadow_activations(
+            m, (torch.randn(1, 1, 4, 4),),
+            results_len=0)
+
+    def test_op_with_only_kwargs_skips_shadowing(self):
+        class M(nn.Module):
+            def forward(self, x):
+                x = torch.cat(tensors=[x])
+                x = torch.stack(tensors=[x])
                 return x
 
         m = M().eval()
@@ -2101,7 +2088,7 @@
             x = torch.randn(2, 4)
             self._test_match_shadow_activations(
                 sparse_nn, (idx, offsets, x),
-                results_len=4,
+                results_len=3,
                 should_log_inputs=should_log_inputs)
 
     @skip_if_no_torchvision
diff --git a/torch/ao/ns/fx/graph_passes.py b/torch/ao/ns/fx/graph_passes.py
index 98a0228..aa99ef7 100644
--- a/torch/ao/ns/fx/graph_passes.py
+++ b/torch/ao/ns/fx/graph_passes.py
@@ -602,6 +602,26 @@
                 subgraph_a, ref_name, ref_node_type_a, ref_node_type_b = \
                     end_node_b_to_matched_subgraph_a_and_name[node_b]
 
+            if len(node_b.args) == 0:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', kwargs-only node not handled yet')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
+            all_op_types_support_shadowing = (
+                op_type_supports_shadowing(subgraph_a.start_node) and
+                op_type_supports_shadowing(node_b)
+            )
+            if not all_op_types_support_shadowing:
+                print(
+                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
+                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
+                    ', unsupported')
+                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
+                continue
+
             # For both start_node and end_node verify that we know how to do
             # the dtype cast. If we do not, skip.
             node_input_type_a, node_output_type_a = \
@@ -626,18 +646,6 @@
                 env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
                 continue
 
-            all_op_types_support_shadowing = (
-                op_type_supports_shadowing(subgraph_a.start_node) and
-                op_type_supports_shadowing(node_b)
-            )
-            if not all_op_types_support_shadowing:
-                print(
-                    f'skipping shadow loggers for node_b: {get_target_type_str(node_b, gm_b)}' +
-                    f', start_node_a: {get_target_type_str(subgraph_a.start_node, gm_a)}' +
-                    ', unsupported')
-                env_c[node_b.name] = graph_c.node_copy(node_b, load_arg)
-                continue
-
             # If we are shadowing from fp32 to int8, we need to insert
             # quantize_per_tensor call with qparams from the previous node.
             # Only do this if we are able to infer these qparams from the graph.
diff --git a/torch/ao/ns/fx/utils.py b/torch/ao/ns/fx/utils.py
index f13d303..8f1f277 100644
--- a/torch/ao/ns/fx/utils.py
+++ b/torch/ao/ns/fx/utils.py
@@ -492,7 +492,7 @@
 
 def op_type_supports_shadowing(node: Node) -> bool:
     if node.op == 'call_function':
-        if node.target in (torch.add, torch.mul, operator.add, operator.mul):
-            # shadowing for ops with two inputs is not implemented yet
+        if node.target in (torch.add, torch.mul, operator.add, operator.mul, torch.cat, torch.stack):
+            # shadowing for ops with multiple tensor inputs is not implemented yet
             return False
     return True