[fx] Add DCE pass (#52658)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52658
DCE reverse-iterates over the graph, looking for nodes without users, and deletes them. It skips unused placeholders (since removing them would change the method's signature) and outputs (which never have users but which we want to keep :) )
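
A minimal usage sketch (the toy module M below is illustrative, not part of this PR):

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, x):
            a = x + 1  # dead: `a` has no users
            return x * 2

    gm = torch.fx.symbolic_trace(M())
    changed = gm.graph.eliminate_dead_code()  # True; `a = x + 1` is removed
    gm.recompile()  # regenerate forward() from the pruned graph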
Test Plan: Added unit tests (test/fx/test_dce_pass.py, wired into test/test_fx.py).
Reviewed By: jamesr66a, khabinov, chenccfb
Differential Revision: D26602212
fbshipit-source-id: f4f196973e40546076636090bb0008c24f33795e
diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py
new file mode 100644
index 0000000..9379de5
--- /dev/null
+++ b/test/fx/test_dce_pass.py
@@ -0,0 +1,182 @@
+import unittest
+
+from typing import Optional, Set, Type
+import torch
+import torch.fx
+
+
+class TestDCE(unittest.TestCase):
+ def _has_nodes_without_users(self, m: torch.fx.GraphModule):
+ for node in m.graph.nodes:
+ if node.is_impure():
+ continue
+ if len(node.users) == 0:
+ return True
+ return False
+
+ def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
+ count = 0
+ for node in m.graph.nodes:
+ if node.op == "placeholder":
+ count += 1
+ return count
+
+ def _run_dce_and_test(
+ self,
+ m: torch.nn.Module,
+ expect_dce_changes: bool,
+ modules_to_be_leafs: Optional[Set[Type]] = None,
+ ):
+ class TestTracer(torch.fx.Tracer):
+ def is_leaf_module(self, m, qualname):
+ if modules_to_be_leafs and type(m) in modules_to_be_leafs:
+ return True
+ return super().is_leaf_module(m, qualname)
+
+ traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
+ print(str(traced.graph))
+
+ # Verify there are nodes without users (if expected).
+ has_nodes_without_users = self._has_nodes_without_users(traced)
+ if expect_dce_changes:
+ self.assertTrue(has_nodes_without_users)
+ else:
+ self.assertFalse(has_nodes_without_users)
+
+ # Get the original number of placeholders to verify it doesn't change
+ # during DCE.
+ orig_num_phs = self._get_num_placeholders(traced)
+ changed = traced.graph.eliminate_dead_code()
+
+ self.assertEqual(changed, expect_dce_changes)
+
+ # Verify there are no nodes without users after DCE is run.
+ self.assertFalse(self._has_nodes_without_users(traced))
+ new_num_phs = self._get_num_placeholders(traced)
+ self.assertEqual(orig_num_phs, new_num_phs)
+
+ traced.recompile()
+ # Make sure we run and get the same results before/after DCE.
+ inputs = [torch.tensor([1.5])] * new_num_phs
+ self.assertTrue(torch.equal(m(*inputs), traced(*inputs)))
+
+ def test_simple(self):
+ """
+ Tests that a single node in the graph is DCE'd correctly.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
+
+ def forward(self, x):
+ a = x + 1
+ return x + self.attr_1
+
+ self._run_dce_and_test(TestModule(), expect_dce_changes=True)
+
+ def test_dead_chain(self):
+ """
+ Tests that a chain of two dead nodes in the graph is DCE'd correctly.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
+
+ def forward(self, x):
+ a = x + 1
+ b = a * 7
+ return x + self.attr_1
+
+ self._run_dce_and_test(TestModule(), expect_dce_changes=True)
+
+ def test_dead_getattr(self):
+ """
+ Tests that a getattr in the graph is DCE'd correctly.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
+
+ def forward(self, x):
+ a = x + 1
+ b = a * self.attr_1
+ return x + 11
+
+ self._run_dce_and_test(TestModule(), expect_dce_changes=True)
+
+ def test_dead_placeholder(self):
+ """
+ Tests that a placeholder in the graph is not DCE'd, as that would change
+ the function signature.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ return x + 7
+
+ self._run_dce_and_test(TestModule(), expect_dce_changes=False)
+
+ def test_dead_placeholder_with_user(self):
+ """
+ Tests that a placeholder in the graph is not DCE'd, as that would change
+ the function signature. Also verifies that a dead node that uses the
+ placeholder is DCE'd.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x, y):
+ a = y + 2
+ return x + 7
+
+ self._run_dce_and_test(TestModule(), expect_dce_changes=True)
+
+ def test_keep_module_with_side_effects(self):
+ """
+ Test that DCE doesn't remove a module if it's specified as having side effects.
+ """
+
+ class ReLUImpure(torch.nn.ReLU):
+ _is_impure = True
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.relu = ReLUImpure()
+
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
+ r = self.relu(a)
+ return a * 2
+
+ self._run_dce_and_test(
+ TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
+ )
+
+ def test_keep_torch_assert(self):
+ """
+ Test that DCE doesn't remove torch._assert since it has side effects.
+ """
+
+ class TestModule(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, a: torch.Tensor) -> torch.Tensor:
+ torch._assert(torch.equal(a, a), "a must equal a")
+ return a * 2
+
+ # Note: Don't need to specify torch._assert as having side effects
+ # because it's known to.
+ self._run_dce_and_test(TestModule(), expect_dce_changes=False)
diff --git a/test/test_fx.py b/test/test_fx.py
index 3e88fc8..dc12868 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -27,6 +27,7 @@
from fx.quantization import Quantizer
from fx.test_subgraph_rewriter import TestSubgraphRewriter # noqa: F401
+from fx.test_dce_pass import TestDCE # noqa: F401
from typing import Any, Callable, Dict, NamedTuple, List, Optional, Tuple, Union
from torch.testing._internal.common_utils import run_tests, TEST_WITH_ROCM, IS_WINDOWS, IS_SANDCASTLE, IS_MACOS
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index b6bffaf..ced2b0a 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -1004,6 +1004,50 @@
raise RuntimeError(f'Node {node} target {node.target} references nonexistent attribute '
f'{atom} of {seen_qualname}')
+ def eliminate_dead_code(self):
+ """
+ Remove all dead code from the graph, based on each node's number of
+ users and whether the nodes have any side effects. The graph must be
+ topologically sorted before this method is called.
+
+ Returns:
+ bool: Whether the graph was changed as a result of the pass.
+
+ Example:
+
+ Before dead code is eliminated, `a` from `a = x + 1` below has no users
+ and thus can be eliminated from the graph without having an effect.
+
+ .. code-block:: python
+
+ def forward(self, x):
+ a = x + 1
+ return x + self.attr_1
+
+ After dead code is eliminated, `a = x + 1` has been removed, and the rest
+ of `forward` remains.
+
+ .. code-block:: python
+
+ def forward(self, x):
+ return x + self.attr_1
+
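+ A typical invocation, given a traced GraphModule ``graph_module``
+ (``recompile()`` regenerates the module's ``forward`` from the
+ mutated graph):
+
+ .. code-block:: python
+
+ changed = graph_module.graph.eliminate_dead_code()
+ if changed:
+ graph_module.recompile()
+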
+ """
+ # Lint the graph first to make sure it's topologically sorted; otherwise
+ # DCE below will not behave as expected.
+ self.lint()
+
+ # Iterate in reverse so that when we remove a node, any nodes it used as
+ # inputs have their user counts updated before we visit them, allowing
+ # entire chains of dead nodes to be removed in one pass.
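+ # For example, with the dead chain `a = x + 1; b = a * 7` (neither used
+ # by the output), `b` is erased first, dropping `a`'s user count to
+ # zero, so `a` is erased later in the same pass.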
+ changed = False
+ for node in reversed(self.nodes):
+ if not node.is_impure() and len(node.users) == 0:
+ self.erase_node(node)
+ changed = True
+
+ return changed
+
reflectable_magic_methods = {
'add': '{} + {}',
diff --git a/torch/fx/node.py b/torch/fx/node.py
index 51290ff..6297c95 100644
--- a/torch/fx/node.py
+++ b/torch/fx/node.py
@@ -1,5 +1,5 @@
# Nodes represent a definition of a value in our graph of operators.
-from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict
+from typing import TYPE_CHECKING, Union, Callable, Any, Tuple, List, Optional, Dict, Set
from .immutable_collections import immutable_dict, immutable_list
import torch
import builtins
@@ -22,6 +22,8 @@
BaseArgumentTypes
]]
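+# Call targets listed here are treated as side-effectful: Node.is_impure()
+# returns True for them, so Graph.eliminate_dead_code() keeps their calls
+# even when the result is unused.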
+_side_effectful_functions: Set[Callable] = {torch._assert}
+
# this is fixed on master, WAR for 1.5
def _find_module_of_method(orig_method: Callable[..., Any]) -> str:
name = orig_method.__name__
@@ -408,6 +410,35 @@
assert len(self.users) == 0
return to_process
+ def is_impure(self) -> bool:
+ """
+ Returns whether this node is impure, i.e. whether its op is a
+ placeholder or output, or whether it is a call_function or
+ call_module with side effects.
+
+ Returns:
+
+ bool: Whether the op is impure.
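+
+ For example (the module name here is illustrative), a call_module
+ node is impure iff its target module sets _is_impure to True:
+
+ .. code-block:: python
+
+ class ReLUWithSideEffects(torch.nn.ReLU):
+ _is_impure = True  # its calls survive eliminate_dead_code()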
+ """
+ if self.op in {"placeholder", "output"}:
+ return True
+
+ # Check if an impure function.
+ if self.op == "call_function":
+ return self.target in _side_effectful_functions
+
+ # Check if an impure module.
+ if self.op == "call_module":
+ assert (
+ self.graph.owning_module is not None
+ ), "self.graph.owning_module not set for purity check"
+ target_mod = self.graph.owning_module.get_submodule(self.target)
+ assert (
+ target_mod is not None
+ ), f"Did not find expected submodule target {self.target}"
+ return getattr(target_mod, "_is_impure", False)
+
+ return False
+
def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:
""" Apply fn to each Node appearing arg. arg may be a list, tuple, slice, or dict with string keys. """
assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable"