Fix fx2trt SplitterBase non_tensor_input logic (#64286)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64286
During graph splitting, `_SplitterBase` can take into consideration whether the subnet boundary nodes
produce "supported" outputs that cross the acc/non-acc boundary. Specifically, if the backend only
supports passing Tensor-based data across the boundary, then we cannot split the graph at a place where
the node output is a non-Tensor type (e.g., `Tuple[Tensor]`).
There is currently a bug in this logic: it does not correctly detect the output type of a Node. Instead
of checking for `Node.meta['tensor_meta']`, we should check `Node.meta['type']`.
`Node.meta['tensor_meta']` is not appropriate because this key is also set when the node output is an
iterable containing at least one `Tensor` element. So a node producing a `Tuple[Tensor]` would be
wrongly considered "supported".
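For illustration, here is a minimal sketch (the module `M` and input shape are hypothetical) showing
that `ShapeProp` sets `tensor_meta` even on a node whose output is a tuple of Tensors, while
`meta['type']` distinguishes the two cases:
```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp

class M(torch.nn.Module):
    def forward(self, x):
        y = x + 1                 # node output type: Tensor
        return torch.split(y, 2)  # node output type: tuple of Tensors

gm = symbolic_trace(M())
# ShapeProp records both 'tensor_meta' and 'type' in each node's meta.
ShapeProp(gm).propagate(torch.randn(4, 4))

for node in gm.graph.nodes:
    # 'tensor_meta' is present for the split node too, because at least
    # one element of its tuple output is a Tensor; its 'type' is `tuple`.
    print(node.name, "tensor_meta" in node.meta, node.meta.get("type"))
```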
Test Plan:
arc lint
run CI tests
Reviewed By: yinghai, 842974287
Differential Revision: D30617147
fbshipit-source-id: e8ba70dfaddc05cafb8037d58fca73b7ccbb1a49
diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py
index 42087bd..6541905 100644
--- a/torch/fx/passes/splitter_base.py
+++ b/torch/fx/passes/splitter_base.py
@@ -2,6 +2,7 @@
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Optional, Tuple
+import logging
import torch
from torch.fx.experimental.graph_manipulation import get_size_of_node
@@ -20,8 +21,12 @@
Tensors,
NodeList,
NodeSet,
+ is_node_output_tensor,
)
+_LOGGER = logging.getLogger(__name__)
+
+
class _SplitterSettingBase:
def __init__(self):
parser = argparse.ArgumentParser()
@@ -98,7 +103,7 @@
for user in node.users:
if user in self.acc_nodes:
self.acc_nodes.remove(user)
- if "tensor_meta" not in user.meta:
+ if not is_node_output_tensor(user):
cpu_worklist.append(user)
def reduce_acc_nodes_non_tensor_input(self):
@@ -113,7 +118,7 @@
continue
if node in self.acc_nodes:
continue
- if "tensor_meta" in node.meta:
+ if is_node_output_tensor(node):
continue
non_tensor_cpu_nodes.append(node)
@@ -128,7 +133,7 @@
new_cpu_nodes: NodeList = []
for acc_node in self.acc_nodes:
- if "tensor_meta" in acc_node.meta:
+ if is_node_output_tensor(acc_node):
continue
for user in acc_node.users:
if user not in self.acc_nodes:
@@ -461,7 +466,7 @@
reports += "Checking inputs...\n"
for n in submod.graph.nodes:
if n.op == "placeholder":
- if "tensor_meta" not in n.meta:
+ if not is_node_output_tensor(n):
reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
else:
total_input_bytes += get_size_of_node(submod, n)[0]
@@ -473,7 +478,7 @@
def get_bytes(node: torch.fx.Node):
nonlocal total_output_bytes
nonlocal reports
- if "tensor_meta" not in node.meta:
+ if not is_node_output_tensor(node):
reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
else:
total_output_bytes += get_size_of_node(submod, node)[0]
diff --git a/torch/fx/passes/tools_common.py b/torch/fx/passes/tools_common.py
index a996dc8..8274f4b 100644
--- a/torch/fx/passes/tools_common.py
+++ b/torch/fx/passes/tools_common.py
@@ -48,6 +48,17 @@
return node.target
+def is_node_output_tensor(node: torch.fx.Node) -> bool:
+ """Checks if the node output produces a Tensor or not.
+
+ NOTE: This requires to run `ShapeProp` on the containing fx graph before
+ calling this function. This is because it works by checking the `type`
+ metadata on the node. This metadata is produced by the `ShapeProp`.
+ """
+ type_ = node.meta.get("type", None)
+ return type_ is not None and issubclass(type_, torch.Tensor)
+
+
class FxNetAccFusionsFinder:
"""
Finds groups of connected ACC nodes that pass non-tensor data between each other.
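
As a usage sketch for the new helper (the module `M` below is hypothetical), `is_node_output_tensor`
returns `True` only once `ShapeProp` has populated `meta['type']` with a Tensor subclass:
```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp
from torch.fx.passes.tools_common import is_node_output_tensor

class M(torch.nn.Module):
    def forward(self, x):
        return torch.chunk(torch.relu(x), 2)

gm = symbolic_trace(M())
ShapeProp(gm).propagate(torch.randn(2, 3))

for node in gm.graph.nodes:
    if node.op == "call_function":
        # relu  -> True  (meta['type'] is torch.Tensor)
        # chunk -> False (meta['type'] is tuple, not a Tensor subclass)
        print(node.name, is_node_output_tensor(node))
```
Note that without running `ShapeProp` first, `meta['type']` is absent and the helper conservatively
returns `False`.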