[inductor] Fix bug where a node gets erased twice (#100848)
Fixes #100806
The underlying bug is if you erase an FX node twice, everything runs without error, but `len(graph.nodes)` reports the incorrect value.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100848
Approved by: https://github.com/ngimel, https://github.com/Skylion007
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 1d0d157..38bf2a1 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -670,6 +670,36 @@
with torch.no_grad():
self.common(mod, (torch.randn(4, 4),))
+ def test_issue100806(self):
+ class Model(torch.nn.Module):
+ def __init__(self):
+ super(Model, self).__init__()
+ self.linear1 = torch.nn.Linear(10, 20)
+ self.linear2 = torch.nn.Linear(20, 30)
+ self.relu = torch.nn.ReLU()
+
+ def forward(self, x):
+ x = self.linear1(x)
+ x = self.linear2(x)
+ x = torch.cat((x, x), dim=1)
+ x = x.view(-1, 2, 30)
+ x = x[:, 1, :]
+ x = self.relu(x)
+ return x
+
+ device = "cuda"
+ batch_size = 2
+ x = torch.randn(batch_size, 10).to(device)
+ func = Model().to(device)
+
+ with torch.no_grad():
+ func.train(False)
+ jit_func = torch.compile(func)
+
+ res1 = func(x)
+ res2 = jit_func(x)
+ self.assertEqual(res1, res2)
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c4359bf..94239f4 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -76,7 +76,8 @@
def erase_nodes(self, graph: torch.fx.Graph):
for n in reversed(self.nodes):
- graph.erase_node(n)
+ if not n._erased:
+ graph.erase_node(n)
def output_nodes(self):
return [
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index ba8d579..497d434 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -867,6 +867,9 @@
if len(to_erase.users) > 0:
raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
f'users in the graph: {to_erase.users}!')
+ if to_erase._erased:
+ warnings.warn(f"erase_node({to_erase}) on an already erased node")
+ return
to_erase._remove_from_list()
to_erase._erased = True # iterators may retain handles to erased nodes