Prevent infinite recursion within Tensor.__repr__ (#120206)
`Tensor.__repr__` calls functions that can perform logging, and that logging ends up formatting `self` (via `__repr__` again), causing an infinite loop. Instead of logging all the args in `FakeTensor.dispatch`, log the actual parameters (and use `id` to log the tensor itself).
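A toy illustration of the failure mode (hypothetical code, not PyTorch's): any `__repr__` that reaches a logger call which formats `self` recurses until Python raises `RecursionError`.
```
import logging

logging.basicConfig(level=logging.DEBUG)
log = logging.getLogger(__name__)

class Chatty:
    def __repr__(self):
        # Rendering `self` for the log record calls __repr__ again,
        # which logs again, and so on until RecursionError.
        log.debug("computing repr of %s", self)
        return "Chatty()"

# repr(Chatty())  # uncommenting this raises RecursionError
```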
The change to `torch/testing/_internal/common_utils.py` came up during testing: in some ways of running the test, `parts` was `('test', 'test_testing.py')`, so `i` was 0 and we were calling `os.path.join` on the empty tuple `()`, which raised an error.
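A minimal sketch of that failure and of the guard from the diff below (the path is hypothetical):
```
import os
from pathlib import Path

parts = Path("test/test_testing.py").parts  # ('test', 'test_testing.py')
i = parts.index("test")                     # 0, so parts[:i] is ()
# os.path.join(*parts[:i])                  # os.path.join() with no arguments: TypeError
base_dir = os.path.join(*parts[:i]) if i > 0 else ''  # guarded form from the diff
print(repr(base_dir))                       # ''
```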
Repro:
```
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
t = torch.sparse_coo_tensor(((0, 1), (1, 0)), (1, 2), size=(2, 2))
t2 = FakeTensor.from_tensor(t, FakeTensorMode())
print(repr(t2))
```
and run with `TORCH_LOGS=+all` so the debug logging is actually emitted.
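For context on the fix itself: `no_dispatch()` (already used throughout `fake_tensor.py`) temporarily disables `__torch_dispatch__` interposition, so any tensor ops triggered while the log arguments are being formatted cannot re-enter `dispatch`. A rough sketch of the effect, assuming a recent PyTorch build:
```
import torch
from torch.utils._mode_utils import no_dispatch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    a = torch.empty(2, 2)      # intercepted by the mode: a FakeTensor
    with no_dispatch():
        b = torch.empty(2, 2)  # mode bypassed: a plain real tensor

print(type(a).__name__, type(b).__name__)  # FakeTensor Tensor
```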
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120206
Approved by: https://github.com/yanboliang, https://github.com/pearu
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 9871c42..2f525d6 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1219,7 +1219,8 @@
     def dispatch(self, func, types, args=(), kwargs=None):
         kwargs = kwargs or {}
-        log.debug("%s %s %s", func, args, kwargs)
+        with no_dispatch():
+            log.debug("%s %s %s", func, args, kwargs)
         if func in _DISPATCH_META_HANDLERS:
             return _DISPATCH_META_HANDLERS[func](args)
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index eb2abb9..64eeb6c 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -2654,7 +2654,7 @@
     parts = Path(abs_test_path).parts
     for i, part in enumerate(parts):
         if part == "test":
-            base_dir = os.path.join(*parts[:i])
+            base_dir = os.path.join(*parts[:i]) if i > 0 else ''
             return os.path.relpath(abs_test_path, start=base_dir)
     # Can't determine containing dir; just return the test filename.