Fix fx2trt SplitterBase non_tensor_input logic (#64286)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64286

During graph splitting, `_SplitterBase` supports taking into consideration whether the subnet boundary nodes
produces "supported" outputs that will cross the acc/non-acc boundary. Specifically, if the backend only
supports Tensor-based data passing cross boundary, then we cannot split the graph at a place where the node
output is a non-Tensor type (e.g., `Tuple[Tensor]`).

There's currently a bug in this logic that it does not correctly detect the output type of a Node. Instead of
using `Node.meta['tensor_meta']`, we should instead check `Node.meta['type']`.

`Node.meta['tensor_meta']` is not appropriate because this key will exist if the node output is an iterable
and one of the element is of type `Tensor`. So `Tuple[Tensor]` will be wrongly considered "supported".

Test Plan:
arc lint
run CI tests

Reviewed By: yinghai, 842974287

Differential Revision: D30617147

fbshipit-source-id: e8ba70dfaddc05cafb8037d58fca73b7ccbb1a49
diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py
index 42087bd..6541905 100644
--- a/torch/fx/passes/splitter_base.py
+++ b/torch/fx/passes/splitter_base.py
@@ -2,6 +2,7 @@
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Dict, Optional, Tuple
+import logging
 
 import torch
 from torch.fx.experimental.graph_manipulation import get_size_of_node
@@ -20,8 +21,12 @@
     Tensors,
     NodeList,
     NodeSet,
+    is_node_output_tensor,
 )
 
+_LOGGER = logging.getLogger(__name__)
+
+
 class _SplitterSettingBase:
     def __init__(self):
         parser = argparse.ArgumentParser()
@@ -98,7 +103,7 @@
             for user in node.users:
                 if user in self.acc_nodes:
                     self.acc_nodes.remove(user)
-                    if "tensor_meta" not in user.meta:
+                    if not is_node_output_tensor(user):
                         cpu_worklist.append(user)
 
     def reduce_acc_nodes_non_tensor_input(self):
@@ -113,7 +118,7 @@
                 continue
             if node in self.acc_nodes:
                 continue
-            if "tensor_meta" in node.meta:
+            if is_node_output_tensor(node):
                 continue
             non_tensor_cpu_nodes.append(node)
 
@@ -128,7 +133,7 @@
             new_cpu_nodes: NodeList = []
 
             for acc_node in self.acc_nodes:
-                if "tensor_meta" in acc_node.meta:
+                if is_node_output_tensor(acc_node):
                     continue
                 for user in acc_node.users:
                     if user not in self.acc_nodes:
@@ -461,7 +466,7 @@
                 reports += "Checking inputs...\n"
                 for n in submod.graph.nodes:
                     if n.op == "placeholder":
-                        if "tensor_meta" not in n.meta:
+                        if not is_node_output_tensor(n):
                             reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                         else:
                             total_input_bytes += get_size_of_node(submod, n)[0]
@@ -473,7 +478,7 @@
                 def get_bytes(node: torch.fx.Node):
                     nonlocal total_output_bytes
                     nonlocal reports
-                    if "tensor_meta" not in node.meta:
+                    if not is_node_output_tensor(node):
                         reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                     else:
                         total_output_bytes += get_size_of_node(submod, node)[0]
diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py
index a996dc8..8274f4b 100644
--- a/torch/fx/passes/tools_common.py
+++ b/torch/fx/passes/tools_common.py
@@ -48,6 +48,17 @@
         return node.target
 
 
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+    """Checks if the node output produces a Tensor or not.
+
+    NOTE: This requires to run `ShapeProp` on the containing fx graph before
+    calling this function. This is because it works by checking the `type`
+    metadata on the node. This metadata is produced by the `ShapeProp`.
+    """
+    type_ = node.meta.get("type", None)
+    return type_ is not None and issubclass(type_, torch.Tensor)
+
+
 class FxNetAccFusionsFinder:
     """
     Finds groups of connected ACC nodes that pass non-tensor data between each other.