[inductor] Fix bug where a node gets erased twice (#100848)

Fixes #100806

The underlying bug is that if you erase an FX node twice, everything runs without error, but `len(graph.nodes)` reports an incorrect value: the graph's internal length counter is decremented on every `erase_node` call, including repeated calls on the same node.
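
A minimal sketch of the failure mode on a build without this patch (a hand-written repro, not the PR's test; it uses only the public `torch.fx` API):

```python
import torch
import torch.fx

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)
g = gm.graph

# Append a dead node (no users), so erasing it is legal.
x_node = next(iter(g.nodes))                  # the placeholder for `x`
dead = g.call_function(torch.neg, (x_node,))

before = len(g.nodes)
g.erase_node(dead)
g.erase_node(dead)  # second erase: silently succeeds before this fix
print(before - len(g.nodes))  # 2, even though only one node was removed
```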

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100848
Approved by: https://github.com/ngimel, https://github.com/Skylion007
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index 1d0d157..38bf2a1 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -670,6 +670,36 @@
         with torch.no_grad():
             self.common(mod, (torch.randn(4, 4),))
 
+    def test_issue100806(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear1 = torch.nn.Linear(10, 20)
+                self.linear2 = torch.nn.Linear(20, 30)
+                self.relu = torch.nn.ReLU()
+
+            def forward(self, x):
+                x = self.linear1(x)
+                x = self.linear2(x)
+                x = torch.cat((x, x), dim=1)
+                x = x.view(-1, 2, 30)
+                x = x[:, 1, :]
+                x = self.relu(x)
+                return x
+
+        device = "cuda"
+        batch_size = 2
+        x = torch.randn(batch_size, 10).to(device)
+        func = Model().to(device)
+
+        with torch.no_grad():
+            func.train(False)
+            jit_func = torch.compile(func)
+
+            res1 = func(x)
+            res2 = jit_func(x)
+            self.assertEqual(res1, res2)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py
index c4359bf..94239f4 100644
--- a/torch/_inductor/pattern_matcher.py
+++ b/torch/_inductor/pattern_matcher.py
@@ -76,7 +76,8 @@
 
     def erase_nodes(self, graph: torch.fx.Graph):
         for n in reversed(self.nodes):
-            graph.erase_node(n)
+            if not n._erased:
+                graph.erase_node(n)
 
     def output_nodes(self):
         return [
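For context on why `erase_nodes` can see the same node twice: a match's node list can contain duplicates when one value feeds the matched subgraph through several inputs, which is presumably what `torch.cat((x, x), dim=1)` in the test above triggers. A rough standalone sketch of the guarded loop (the `matched` list is a stand-in for `Match.nodes`, not the matcher's real state):

```python
import torch
import torch.fx

def f(x):
    return x * 2

gm = torch.fx.symbolic_trace(f)
g = gm.graph
dead = g.call_function(torch.neg, (next(iter(g.nodes)),))  # dead node, safe to erase

matched = [dead, dead]  # the same node object listed twice
for n in reversed(matched):
    if not n._erased:   # skip nodes already erased by an earlier iteration
        g.erase_node(n)

assert len(g.nodes) == len(list(g.nodes))  # counter agrees with the actual list
```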
diff --git a/torch/fx/graph.py b/torch/fx/graph.py
index ba8d579..497d434 100644
--- a/torch/fx/graph.py
+++ b/torch/fx/graph.py
@@ -867,6 +867,9 @@
         if len(to_erase.users) > 0:
             raise RuntimeError(f'Tried to erase Node {to_erase} but it still had {len(to_erase.users)} '
                                f'users in the graph: {to_erase.users}!')
+        if to_erase._erased:
+            warnings.warn(f"erase_node({to_erase}) on an already erased node")
+            return
 
         to_erase._remove_from_list()
         to_erase._erased = True  # iterators may retain handles to erased nodes
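
With the patch applied, a double erase degrades to a warning instead of silently corrupting the node count. A usage sketch on a patched build (same hand-rolled setup as above, not from the PR):

```python
import warnings
import torch
import torch.fx

def f(x):
    return x + 1

gm = torch.fx.symbolic_trace(f)
g = gm.graph
dead = g.call_function(torch.neg, (next(iter(g.nodes)),))

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    g.erase_node(dead)  # removes the node, decrements the count once
    g.erase_node(dead)  # warns "erase_node(...) on an already erased node" and returns

assert len(caught) == 1
assert len(g.nodes) == len(list(g.nodes))  # count stays consistent
```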