Move module_tracker to logging for confused hierarchy (#134467) (#134501)

* Move module_tracker to logging for confused hierarchy (#134467)

Fixes https://github.com/pytorch/pytorch/issues/134242

Make sure the tracker never raises an error when its hierarchy tracking gets confused. Logs for these confusion cases can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or via the usual Python logging mechanisms.
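
As a concrete sketch (an illustration, not part of the patch), either approach below surfaces those messages; the logger name follows the module path, matching the `logging.getLogger(__name__)` call added in the diff:

```python
# Option 1: via the TORCH_LOGS environment variable, before launching Python:
#   TORCH_LOGS="torch.utils.module_tracker" python train.py

# Option 2: via standard Python logging, from inside the program:
import logging

logging.basicConfig()  # attach a default handler to the root logger
logging.getLogger("torch.utils.module_tracker").setLevel(logging.INFO)
```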

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134467
Approved by: https://github.com/malfet

* Fix bad merge conflict resolution
diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py
index 450a787..457b964 100644
--- a/test/test_module_tracker.py
+++ b/test/test_module_tracker.py
@@ -3,7 +3,9 @@
 from copy import copy
 
 import torch
+from torch import nn
 from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
+from torch.utils.checkpoint import checkpoint
 from torch.utils.module_tracker import ModuleTracker
 
 
@@ -14,7 +16,7 @@
         seen_fw = []
         seen_bw = []
 
-        class Foo(torch.nn.Module):
+        class Foo(nn.Module):
             def forward(self, x):
                 x = x["a"].relu_()
                 seen_fw.append((copy(tracker.parents), tracker.is_bw))
@@ -23,12 +25,12 @@
                 )
                 return {"a": torch.mm(x, x)}
 
-        class Mod(torch.nn.Module):
-            def __init__(self):
+        class Mod(nn.Module):
+            def __init__(self) -> None:
                 super().__init__()
                 self.a = Foo()
-                self.b = torch.nn.ModuleDict({"nest": Foo()})
-                self.c = torch.nn.ModuleList([Foo()])
+                self.b = nn.ModuleDict({"nest": Foo()})
+                self.c = nn.ModuleList([Foo()])
 
             def forward(self, x):
                 x = self.c[0](x)
@@ -68,8 +70,36 @@
             ],
         )
 
+    def test_confused_hierarchy(self):
+        class MyMod(nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.inner = nn.Linear(2, 2)
+                self.ran = False
+
+            def forward(self, inp):
+                if not self.ran:
+                    self.ran = True
+                    return self(inp)
+                else:
+                    self.ran = False
+                    return self.inner(inp)
+
+        mod = MyMod()
+        inp = torch.rand(1, 2, requires_grad=True)
+
+        # Should not fail
+        with ModuleTracker() as tracker:
+            res = mod(inp)
+            res.sum().backward()
+
+        # Should not fail
+        with ModuleTracker() as tracker:
+            res = checkpoint(lambda inp: mod(inp), inp)
+            res.sum().backward()
+
     def test_bw_detection(self):
-        mod = torch.nn.Linear(2, 2)
+        mod = nn.Linear(2, 2)
 
         with ModuleTracker() as tracker:
             mod(torch.rand(2, requires_grad=True)).sum().backward()
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index cde56a6..c41eee5 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -158,7 +158,8 @@
 
 def _get_grad_fn_or_grad_acc(t):
     if t.requires_grad and t.grad_fn is None:
-        return t.view_as(t).grad_fn.next_functions[0][0]
+        with torch.enable_grad():
+            return t.view_as(t).grad_fn.next_functions[0][0]
     else:
         return t.grad_fn
 
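A note on the `torch.enable_grad()` wrapper above (my reading of the change, not wording from the PR): `t.view_as(t)` only records a `grad_fn` while grad mode is on, so when this helper ran under `torch.no_grad()` the old code dereferenced `None.next_functions`. A minimal sketch of the failure mode:

```python
import torch

t = torch.rand(2, requires_grad=True)  # leaf tensor, so t.grad_fn is None

with torch.no_grad():
    v = t.view_as(t)
    print(v.grad_fn)  # None: no autograd node is recorded in no-grad mode
    # v.grad_fn.next_functions would raise AttributeError here

with torch.enable_grad():
    v = t.view_as(t)
    # ViewBackward0, whose next_functions[0][0] is t's AccumulateGrad node
    print(v.grad_fn)
```
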
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 9feef40..0e9bfaa 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -1,4 +1,5 @@
 # mypy: allow-untyped-defs
+import logging
 import weakref
 
 from typing import Set
@@ -11,6 +12,10 @@
 )
 from torch.utils._pytree import tree_flatten
 
+
+logger = logging.getLogger(__name__)
+
+
 __all__ = ["ModuleTracker"]
 
 
@@ -93,9 +98,10 @@
             if is_bw:
                 self._maybe_set_engine_callback()
             if name in self.parents:
-                print(
-                    "The module hierarchy tracking seems to be messed up."
-                    "Please file a bug to PyTorch."
+                logger.info(
+                    "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
+                    name,
+                    "backward" if is_bw else "forward",
                 )
             self.parents.add(name)
 
@@ -105,11 +111,11 @@
         def fn(*args):
             if name in self.parents:
                 self.parents.remove(name)
-            elif not is_bw:
-                # Due to some input/output not requiring gradients, we cannot enforce
-                # proper nesting in backward
-                raise RuntimeError(
-                    "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
+            else:
+                logger.info(
+                    "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
+                    name,
+                    "backward" if is_bw else "forward",
                 )
 
         return fn
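
Taken together, a hedged end-to-end sketch of the new behavior, mirroring `test_confused_hierarchy` above: the re-entrant `forward` used to trip the `RuntimeError` removed in the last hunk, and now only produces INFO log records while forward and backward complete normally:

```python
import logging

import torch
from torch import nn
from torch.utils.module_tracker import ModuleTracker

# Make the module_tracker INFO messages visible (see the logger added above).
logging.basicConfig()
logging.getLogger("torch.utils.module_tracker").setLevel(logging.INFO)


class MyMod(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.inner = nn.Linear(2, 2)
        self.ran = False

    def forward(self, inp):
        if not self.ran:
            self.ran = True
            return self(inp)  # re-entry: the same name is "entered" twice
        self.ran = False
        return self.inner(inp)


# Previously this could raise "The Module hierarchy tracking is wrong";
# now the mismatched enter/exit events are only logged.
with ModuleTracker():
    res = MyMod()(torch.rand(1, 2, requires_grad=True))
    res.sum().backward()
```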