[torch][fx] Add ignore_parameters_and_buffers kwarg to FxGraphDrawer (#79982)
Summary:
Add an `ignore_parameters_and_buffers` parameter which will tell the graph drawer
to leave off adding parameter and buffer nodes in the dot graph.
This is useful for large networks, where we want to view the graph to get an idea of
the topology and the shapes without needing to see every detail. Removing these buffers
de-clutters the graph significantly without detracting much information.
Reviewed By: jfix71
Differential Revision: D37317917
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79982
Approved by: https://github.com/jfix71
diff --git a/torch/fx/passes/graph_drawer.py b/torch/fx/passes/graph_drawer.py
index 045f019..6f6d7b7 100644
--- a/torch/fx/passes/graph_drawer.py
+++ b/torch/fx/passes/graph_drawer.py
@@ -65,12 +65,13 @@
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool = False,
+ ignore_parameters_and_buffers: bool = False,
skip_node_names_in_args: bool = True,
):
self._name = name
self._dot_graphs = {
name: self._to_dot(
- graph_module, name, ignore_getattr, skip_node_names_in_args
+ graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args
)
}
@@ -87,6 +88,7 @@
leaf_node,
f"{name}_{node.target}",
ignore_getattr,
+ ignore_parameters_and_buffers,
skip_node_names_in_args,
)
@@ -258,10 +260,13 @@
graph_module: torch.fx.GraphModule,
name: str,
ignore_getattr: bool,
+ ignore_parameters_and_buffers: bool,
skip_node_names_in_args: bool,
) -> pydot.Dot:
"""
- Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph
+ Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph.
+ If ignore_parameters_and_buffers is True, the parameters and buffers
+ created with the module will not be added as nodes and edges.
"""
dot_graph = pydot.Dot(name, rankdir="TB")
@@ -296,7 +301,7 @@
if node.op == "call_module":
leaf_module = self._get_leaf_node(graph_module, node)
- if not isinstance(leaf_module, torch.fx.GraphModule):
+ if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule):
get_module_params_or_buffers()
for node in graph_module.graph.nodes: