| # Owner(s): ["module: fx"] |
| import torch |
| from torch.fx import symbolic_trace |
| from torch.testing._internal.common_utils import TestCase |
| |
| |
| class TestFXNodeHook(TestCase): |
| def test_hooks_for_node_update(self): |
| global create_node_hook1_called |
| global create_node_hook2_called |
| global erase_node_hook1_called |
| global erase_node_hook2_called |
| create_node_hook1_called = False |
| create_node_hook2_called = False |
| erase_node_hook1_called = False |
| erase_node_hook2_called = False |
| |
| def fn(a, b, c): |
| x = torch.nn.functional.linear(a, b) |
| x = x + c |
| return x.cos() |
| |
| def create_node_hook1(node): |
| global create_node_hook1_called |
| create_node_hook1_called = True |
| |
| def create_node_hook2(node): |
| global create_node_hook2_called |
| create_node_hook2_called = True |
| |
| def erase_node_hook1(node): |
| global erase_node_hook1_called |
| erase_node_hook1_called = True |
| |
| def erase_node_hook2(node): |
| global erase_node_hook2_called |
| erase_node_hook2_called = True |
| |
| gm = symbolic_trace(fn) |
| gm._register_create_node_hook(create_node_hook1) |
| gm._register_create_node_hook(create_node_hook2) |
| gm._register_erase_node_hook(erase_node_hook1) |
| gm._register_erase_node_hook(erase_node_hook2) |
| |
| graph = gm.graph |
| node_a = None |
| for node in graph.find_nodes(op="placeholder"): |
| node_a = node |
| break |
| assert node_a is not None |
| # This will create a new node |
| node_a_copy = graph.node_copy(node_a) |
| node_a.replace_all_uses_with(node_a_copy) |
| graph.erase_node(node_a) |
| |
| assert ( |
| create_node_hook1_called |
| and create_node_hook2_called |
| and erase_node_hook1_called |
| and erase_node_hook2_called |
| ) |
| |
| gm._unregister_create_node_hook(create_node_hook1) |
| gm._unregister_create_node_hook(create_node_hook2) |
| gm._unregister_erase_node_hook(erase_node_hook1) |
| gm._unregister_erase_node_hook(erase_node_hook2) |
| |
| assert gm._create_node_hooks == [] |
| assert gm._erase_node_hooks == [] |