Move module_tracker to logging for confused hierarchy (#134467) (#134501)
* Move module_tracker to logging for confused hierarchy (#134467)
Fixes https://github.com/pytorch/pytorch/issues/134242
Make sure the tracker never raises an error when the module hierarchy is confused. Logging for these cases can be enabled with `TORCH_LOGS="torch.utils.module_tracker"` or the standard Python logging machinery (see the sketch after the commit notes below).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134467
Approved by: https://github.com/malfet
* Fix bad merge conflict resolution
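A minimal sketch (not part of the patch) of surfacing the new confusion logs with plain Python logging; the logger name `torch.utils.module_tracker` follows the `logging.getLogger(__name__)` call added in `torch/utils/module_tracker.py` below, and the messages are emitted at `INFO` level:

```python
# Sketch only: enable the ModuleTracker confusion logs via standard logging.
# Alternatively, set TORCH_LOGS="torch.utils.module_tracker" in the environment.
import logging

import torch
from torch import nn
from torch.utils.module_tracker import ModuleTracker

logging.basicConfig()
# The tracker logs at INFO level, so lower the threshold for its logger.
logging.getLogger("torch.utils.module_tracker").setLevel(logging.INFO)

mod = nn.Linear(2, 2)
with ModuleTracker() as tracker:
    mod(torch.rand(2, requires_grad=True)).sum().backward()
```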
diff --git a/test/test_module_tracker.py b/test/test_module_tracker.py
index 450a787..457b964 100644
--- a/test/test_module_tracker.py
+++ b/test/test_module_tracker.py
@@ -3,7 +3,9 @@
from copy import copy
import torch
+from torch import nn
from torch.testing._internal.common_utils import run_tests, TestCase, xfailIfTorchDynamo
+from torch.utils.checkpoint import checkpoint
from torch.utils.module_tracker import ModuleTracker
@@ -14,7 +16,7 @@
seen_fw = []
seen_bw = []
- class Foo(torch.nn.Module):
+ class Foo(nn.Module):
def forward(self, x):
x = x["a"].relu_()
seen_fw.append((copy(tracker.parents), tracker.is_bw))
@@ -23,12 +25,12 @@
)
return {"a": torch.mm(x, x)}
- class Mod(torch.nn.Module):
- def __init__(self):
+ class Mod(nn.Module):
+ def __init__(self) -> None:
super().__init__()
self.a = Foo()
- self.b = torch.nn.ModuleDict({"nest": Foo()})
- self.c = torch.nn.ModuleList([Foo()])
+ self.b = nn.ModuleDict({"nest": Foo()})
+ self.c = nn.ModuleList([Foo()])
def forward(self, x):
x = self.c[0](x)
@@ -68,8 +70,36 @@
],
)
+ def test_confused_hierarchy(self):
+ class MyMod(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.inner = nn.Linear(2, 2)
+ self.ran = False
+
+ def forward(self, inp):
+ if not self.ran:
+ self.ran = True
+ return self(inp)
+ else:
+ self.ran = False
+ return self.inner(inp)
+
+ mod = MyMod()
+ inp = torch.rand(1, 2, requires_grad=True)
+
+ # Should not fail
+ with ModuleTracker() as tracker:
+ res = mod(inp)
+ res.sum().backward()
+
+ # Should not fail
+ with ModuleTracker() as tracker:
+ res = checkpoint(lambda inp: mod(inp), inp)
+ res.sum().backward()
+
def test_bw_detection(self):
- mod = torch.nn.Linear(2, 2)
+ mod = nn.Linear(2, 2)
with ModuleTracker() as tracker:
mod(torch.rand(2, requires_grad=True)).sum().backward()
diff --git a/torch/autograd/graph.py b/torch/autograd/graph.py
index cde56a6..c41eee5 100644
--- a/torch/autograd/graph.py
+++ b/torch/autograd/graph.py
@@ -158,7 +158,8 @@
def _get_grad_fn_or_grad_acc(t):
if t.requires_grad and t.grad_fn is None:
- return t.view_as(t).grad_fn.next_functions[0][0]
+ with torch.enable_grad():
+ return t.view_as(t).grad_fn.next_functions[0][0]
else:
return t.grad_fn
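Aside (not part of the patch): a small sketch of why the `torch.enable_grad()` guard above matters. Under `torch.no_grad()`, `view_as()` records no `grad_fn`, so the previous code would fail when `_get_grad_fn_or_grad_acc` is reached for a leaf tensor while grad mode is disabled:

```python
import torch

t = torch.rand(2, requires_grad=True)  # leaf tensor: t.grad_fn is None

with torch.no_grad():
    # No autograd graph is recorded, so there is no grad_fn to inspect.
    assert t.view_as(t).grad_fn is None

with torch.enable_grad():
    # A ViewBackward node is recorded; its first next_function is the
    # AccumulateGrad node for the leaf tensor t.
    acc = t.view_as(t).grad_fn.next_functions[0][0]
    print(type(acc).__name__)  # AccumulateGrad
```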
diff --git a/torch/utils/module_tracker.py b/torch/utils/module_tracker.py
index 9feef40..0e9bfaa 100644
--- a/torch/utils/module_tracker.py
+++ b/torch/utils/module_tracker.py
@@ -1,4 +1,5 @@
# mypy: allow-untyped-defs
+import logging
import weakref
from typing import Set
@@ -11,6 +12,10 @@
)
from torch.utils._pytree import tree_flatten
+
+logger = logging.getLogger(__name__)
+
+
__all__ = ["ModuleTracker"]
@@ -93,9 +98,10 @@
if is_bw:
self._maybe_set_engine_callback()
if name in self.parents:
- print(
- "The module hierarchy tracking seems to be messed up."
- "Please file a bug to PyTorch."
+ logger.info(
+ "The module hierarchy tracking seems to be broken as this Module was already entered. %s during %s",
+ name,
+ "backward" if is_bw else "forward",
)
self.parents.add(name)
@@ -105,11 +111,11 @@
def fn(*args):
if name in self.parents:
self.parents.remove(name)
- elif not is_bw:
- # Due to some input/output not requiring gradients, we cannot enforce
- # proper nesting in backward
- raise RuntimeError(
- "The Module hierarchy tracking is wrong. Report a bug to PyTorch"
+ else:
+ logger.info(
+ "The Module hierarchy tracking is confused as we're exiting a Module that was never entered. %s during %s",
+ name,
+ "backward" if is_bw else "forward",
)
return fn