# blob: ab107716208dc0e47dc8ab4477447c02bb4c5044 [file] [log] [blame]
# Owner(s): ["module: fx"]
import copy
import unittest
from typing import Optional, Set, Type

import torch
import torch.fx
from torch.testing._internal.common_utils import IS_MACOS, TestCase
class TestDCE(TestCase):
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
if node.is_impure():
return True
# a custom function that defines add operators as impure.
if node.target == torch.ops.aten.add:
return True
return False
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
for node in m.graph.nodes:
if (not custom and node.is_impure()) or (
custom and self._custom_is_impure_node(node)
):
continue
if len(node.users) == 0:
return True
return False
def _get_num_placeholders(self, m: torch.fx.GraphModule) -> int:
count = 0
for node in m.graph.nodes:
if node.op == "placeholder":
count += 1
return count
def _run_dce_and_test(
self,
m: torch.nn.Module,
expect_dce_changes: bool,
modules_to_be_leafs: Set[Type] = None,
custom: bool = False,
):
class TestTracer(torch.fx.Tracer):
def is_leaf_module(self, m, qualname):
if modules_to_be_leafs and type(m) in modules_to_be_leafs:
return True
return super().trace(m, qualname)
traced: torch.fx.GraphModule = torch.fx.GraphModule(m, TestTracer().trace(m))
print(str(traced.graph))
# Verify there are nodes without users (if expected).
has_nodes_without_users = self._has_nodes_without_users(traced, custom=custom)
if expect_dce_changes:
self.assertTrue(has_nodes_without_users)
else:
self.assertFalse(has_nodes_without_users)
# Get the original number of placeholders to verify it doesn't change
# during DCE.
orig_num_phs = self._get_num_placeholders(traced)
if custom:
changed = traced.graph.eliminate_dead_code(
is_impure_node=self._custom_is_impure_node
)
else:
changed = traced.graph.eliminate_dead_code()
self.assertTrue(changed if expect_dce_changes else not changed)
# Verify there are no nodes without users after DCE is run.
self.assertFalse(self._has_nodes_without_users(traced, custom=custom))
new_num_phs = self._get_num_placeholders(traced)
self.assertEqual(orig_num_phs, new_num_phs)
traced.recompile()
# Make sure we run and get the same results before/after DCE.
inputs = [torch.tensor([1.5])] * new_num_phs
inputs_copy = copy.deepcopy(inputs)
self.assertTrue(torch.equal(m(*inputs), traced(*inputs_copy)))
def test_simple(self):
"""
Tests that a single node in the graph is DCE'd correctly.
"""
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
def forward(self, x):
a = x + 1
return x + self.attr_1
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
def test_dead_chain(self):
"""
Tests that a chain of two nodes in the graph are DCE'd correctly.
"""
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
def forward(self, x):
a = x + 1
b = a * 7
return x + self.attr_1
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
def test_dead_getattr(self):
"""
Tests that a getatrr in the graph is DCE'd correctly.
"""
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.attr_1 = torch.nn.Parameter(torch.tensor([-0.9]))
def forward(self, x):
a = x + 1
b = a * self.attr_1
return x + 11
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
def test_dead_placeholder(self):
"""
Tests that a placeholder in the graph is not DCE'd, as that would change
the function signature.
"""
class TestModule(torch.nn.Module):
def forward(self, x, y):
return x + 7
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_dead_placeholder_with_user(self):
"""
Tests that a placeholder in the graph is not DCE'd, as that would change
the function signature. Also verifies that a dead node that uses the
placeholder is DCE'd.
"""
class TestModule(torch.nn.Module):
def forward(self, x, y):
a = y + 2
return x + 7
self._run_dce_and_test(TestModule(), expect_dce_changes=True)
def test_keep_module_with_side_effects(self):
"""
Test that DCE doesn't remove a module if it's specified as having side effects.
"""
class ReLUImpure(torch.nn.ReLU):
_is_impure = True
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = ReLUImpure()
def forward(self, a: torch.Tensor) -> torch.Tensor:
r = self.relu(a)
return a * 2
self._run_dce_and_test(
TestModule(), expect_dce_changes=False, modules_to_be_leafs={ReLUImpure}
)
def test_keep_torch_assert(self):
"""
Test that DCE doesn't remove torch._assert since it has side effects.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor) -> torch.Tensor:
torch._assert(torch.equal(a, a), "a must equal a")
return a * 2
# Note: Don't need to specify torch._assert as having side effects
# because it's known to.
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_impure_nodes_args(self):
"""
Test that DCE doesn't remove call_function nodes with side effects.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor) -> torch.Tensor:
torch._ops.ops.aten.add_.Tensor(a, 1)
return a * 2
# %add_ node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_impure_kwargs(self):
"""
Test that DCE doesn't remove call_function nodes with side effects on kwargs.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor) -> torch.Tensor:
b = a + 1
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
return a
# %add_out node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_impure_custom(self):
"""
Test that DCE doesn't remove nodes marked as impure by a custom function.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor) -> torch.Tensor:
b = a + 1
c = torch._ops.ops.aten.add(b, b)
return a
# %add_out node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
@unittest.skipIf(IS_MACOS, "Not working on macos")
def test_keep_collectives(self):
"""
Test that DCE doesn't remote collective ops even the results are not used.
"""
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestModule(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
d = torch.ops.aten.mul.Tensor(a, b)
e = torch.ops.aten.mul.Tensor(a, c)
future = torch.ops._c10d_functional.all_reduce.default(e, "sum", "0")
synced_e = torch.ops._c10d_functional.wait_tensor.default(
future
) # synced_e is not used
return d
torch.distributed.init_process_group(
backend="fake",
world_size=2,
rank=0,
store=FakeStore(),
)
# collective nodes should not be removed because they have side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
torch.distributed.destroy_process_group()
@unittest.skipIf(IS_MACOS, "Not working on macos")
def test_keep_collectives_no_overload(self):
"""
Test that DCE doesn't remote collective ops (no overload version) even the results are not used.
"""
from torch.testing._internal.distributed.fake_pg import FakeStore
class TestModule(torch.nn.Module):
def forward(
self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor
) -> torch.Tensor:
d = torch.ops.aten.mul(a, b)
e = torch.ops.aten.mul(a, c)
future = torch.ops._c10d_functional.all_reduce(e, "sum", "0")
synced_e = torch.ops._c10d_functional.wait_tensor(
future
) # synced_e is not used
return d
torch.distributed.init_process_group(
backend="fake",
world_size=2,
rank=0,
store=FakeStore(),
)
# collective nodes should not be removed because they have side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=False)
torch.distributed.destroy_process_group()