[FX] Fix uses not updating when erasing a node (#47720)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47720
Test Plan: Imported from OSS
Reviewed By: zdevito
Differential Revision: D24875880
Pulled By: jamesr66a
fbshipit-source-id: aae9ffd10f8085b599e7923152287c6e6950ff49
diff --git a/test/test_fx.py b/test/test_fx.py
index dcb1045..b035b37 100644
--- a/test/test_fx.py
+++ b/test/test_fx.py
@@ -700,6 +700,19 @@
ref = torch.sin(mod.linear(input) + mod.bias)
self.assertEqual(r, ref)
+ def test_remove_uses(self):
+ g : torch.fx.Graph = Graph()
+ x : torch.fx.Node = g.placeholder('x')
+ relu : torch.fx.Node = g.call_function(torch.relu, (x,))
+ neg : torch.fx.Node = g.call_function(torch.neg, (relu,))
+ g.output(neg)
+
+ neg.replace_all_uses_with(relu)
+ g.erase_node(neg)
+
+ self.assertTrue(neg not in relu.users)
+
+
def test_construct_root_dict(self):
graph : torch.fx.Graph = torch.fx.Graph()
a : torch.fx.Node = graph.create_node('placeholder', 'x')
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index 45e5184..6543847 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -261,6 +261,15 @@
to_erase._erased = True # iterators may retain handles to erased nodes
self._len -= 1
+ # Null out this Node's argument nodes so that the Nodes referred to
+ # can update their `users` accordingly
+ new_args = map_arg(to_erase.args, lambda n: None)
+ assert isinstance(new_args, tuple)
+ to_erase.args = new_args
+ new_kwargs = map_arg(to_erase.kwargs, lambda n: None)
+ assert isinstance(new_kwargs, dict)
+ to_erase.kwargs = new_kwargs
+
def inserting_before(self, n: Optional[Node] = None):
"""Set the point at which create_node and companion methods will insert into the graph.
When used within a 'with' statement, this will temporary set the insert point and