| # Owner(s): ["module: fx"] |
| |
| import torch |
| import torch.fx as fx |
| from torch.fx.passes.infra.pass_base import PassBase, PassResult |
| from torch.fx.passes.infra.pass_manager import ( |
| _topological_sort_passes, |
| pass_result_wrapper, |
| PassManager, |
| this_before_that_pass_constraint, |
| ) |
| |
| from torch.testing._internal.common_utils import TestCase |
| |
| |
| # Pass that uses PassBase and returns a PassResult (best scenario) |
| class ReplaceAddWithMulPass(PassBase): |
| def call(self, gm) -> PassResult: |
| modified = False |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.add: |
| node.target = torch.mul |
| modified = True |
| return PassResult(gm, modified) |
| |
| |
| # Pass that is a callable and returns a PassResult |
| def replace_mul_with_div_pass(gm) -> PassResult: |
| modified = False |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.mul: |
| node.target = torch.div |
| modified = True |
| return PassResult(gm, modified) |
| |
| |
| # Pass that is a PassBase and does not return a PassResult |
| # Need to wrap with pass_result_wrapper or else it will fail |
| class ReplaceDivWithSubPass(PassBase): |
| def call(self, gm) -> PassResult: |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.div: |
| node.target = torch.sub |
| |
| |
| # Pass that is a callable and does not return a PassResult |
| # Need to wrap with pass_result_wrapper or else it will fail |
| def replace_sub_with_add_pass(gm) -> PassResult: |
| for node in gm.graph.nodes: |
| if node.op == "call_function" and node.target == torch.sub: |
| node.target = torch.add |
| |
| |
| class AddModule(torch.nn.Module): |
| def forward(self, x): |
| y = torch.add(x, x) |
| z = torch.add(y, x) |
| return z |
| |
| |
| class TestPassManager(TestCase): |
| def test_pass_manager(self): |
| """ |
| Tests that the pass manager runs the passes correctly. |
| """ |
| |
| m = AddModule() |
| traced_m = torch.fx.symbolic_trace(m) |
| pm = PassManager( |
| passes=[ |
| ReplaceAddWithMulPass(), |
| replace_mul_with_div_pass, |
| pass_result_wrapper(ReplaceDivWithSubPass()), |
| pass_result_wrapper(replace_sub_with_add_pass), |
| ], |
| steps=5, |
| ) |
| |
| pm.validate_constraints() |
| self.assertEqual(len(pm.passes), 4) |
| |
| res = pm(traced_m) |
| modified_m = res.graph_module |
| assert isinstance(modified_m, fx.GraphModule) |
| |
| # Check that all call_function nodes are divs |
| for node in modified_m.graph.nodes: |
| if node.op == "call_function": |
| self.assertEqual(node.target, torch.add) |
| |
| def test_this_before_that_pass_constraint(self): |
| """ |
| Tests the construction of constraints |
| """ |
| passes = [lambda x: 2 * x for _ in range(10)] |
| pm = PassManager(passes) |
| |
| # add unfulfillable constraint |
| pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) |
| |
| with self.assertRaises(RuntimeError): |
| pm.validate_constraints() |
| |
| def test_pass_manager_checks(self): |
| """ |
| Tests that users can add in check functions correctly |
| """ |
| m = AddModule() |
| traced_m = fx.symbolic_trace(m) |
| pm = PassManager(passes=[ReplaceAddWithMulPass(), replace_mul_with_div_pass]) |
| |
| def check_div_target(graph_module): |
| for node in graph_module.graph.nodes: |
| if node.op == "call_function" and node.target != torch.div: |
| raise ValueError("Target should be div!") |
| |
| pm.add_checks(check_div_target) |
| |
| with self.assertRaises(ValueError): |
| pm(traced_m) |
| |
| def test_pass_manager_bad_checks(self): |
| """ |
| Checks that we error if we pass in a check function with the wrong parameters |
| """ |
| |
| def check_bad_args(graph_module, i): |
| pass |
| |
| pm = PassManager() |
| self.assertRaises(TypeError, pm.add_checks, check_bad_args) |
| |
| def test_topological_sort(self): |
| """ |
| Tests that passes are correctly ordered based on contraints. |
| """ |
| |
| def pass0(x): |
| return x |
| |
| def pass1(x): |
| return x + 1 |
| |
| def pass2(x): |
| return x + 2 |
| |
| def pass3(x): |
| return x + 3 |
| |
| def pass4(x): |
| return x + 4 |
| |
| def pass5(x): |
| return x + 5 |
| |
| # Not passing any constraints should keep the original order |
| passes = [pass0, pass1, pass2, pass3, pass4, pass5] |
| sorted = _topological_sort_passes(passes, []) |
| self.assertEqual(sorted, passes) |
| |
| # Graph that we are constructing: |
| # 5 ----> 0 <---- 4 |
| # | | |
| # +-> 2 -> 3 -> 1 <-+ |
| # Which has a possible topological order of: [4, 5, 0, 2, 3, 1] |
| passes = [pass0, pass1, pass2, pass3, pass4, pass5] |
| constraints = [ |
| this_before_that_pass_constraint(pass5, pass0), |
| this_before_that_pass_constraint(pass5, pass2), |
| this_before_that_pass_constraint(pass4, pass0), |
| this_before_that_pass_constraint(pass4, pass1), |
| this_before_that_pass_constraint(pass2, pass3), |
| this_before_that_pass_constraint(pass3, pass1), |
| ] |
| sorted = _topological_sort_passes(passes, constraints) |
| self.assertEqual(sorted, [pass4, pass5, pass0, pass2, pass3, pass1]) |
| |
| # Circular dependency should result in the circular_dep flag being set |
| passes = [pass0, pass1, pass2] |
| constraints = [ |
| this_before_that_pass_constraint(passes[0], passes[1]), |
| this_before_that_pass_constraint(passes[1], passes[2]), |
| this_before_that_pass_constraint(passes[2], passes[0]), |
| ] |
| with self.assertRaises(RuntimeError) as e: |
| _topological_sort_passes(passes, constraints) |
| expected_error_msg = ( |
| f"Circular dependency detected within the following passes: {passes}" |
| ) |
| self.assertEqual(e.exception.args[0], expected_error_msg) |
| |
| def test_pass_manager_error(self): |
| """ |
| Tests error catching + debug |
| """ |
| |
| def pass_fail(graph_module): |
| raise RuntimeError("bad") |
| |
| m = AddModule() |
| traced_m = torch.fx.symbolic_trace(m) |
| pm = PassManager( |
| passes=[ |
| ReplaceAddWithMulPass(), |
| replace_mul_with_div_pass, |
| ReplaceDivWithSubPass(), |
| pass_result_wrapper(replace_sub_with_add_pass), |
| ], |
| ) |
| |
| # Comment out this line to see the actual error message |
| error_msg = ( |
| "ReplaceDivWithSubPass.*ReplaceAddWithMulPass.*replace_mul_with_div_pass" |
| ) |
| with self.assertRaisesRegex(Exception, error_msg): |
| pm(traced_m) |
| |
| pm = PassManager( |
| passes=[ |
| ReplaceAddWithMulPass(), |
| replace_mul_with_div_pass, |
| pass_result_wrapper(ReplaceDivWithSubPass()), |
| pass_result_wrapper(replace_sub_with_add_pass), |
| pass_fail, |
| ], |
| ) |
| |
| # Comment out this line to see the actual error message |
| error_msg = "pass_fail.*ReplaceAddWithMulPass.*replace_mul_with_div_pass.*ReplaceDivWithSubPass.*replace_sub_with_add_pass" |
| with self.assertRaisesRegex(Exception, error_msg): |
| pm(traced_m) |