[compiled autograd] Fix LoggingTensor flaky test (#126144)
LoggingTensor fails consistently when root logger level is INFO or lower
By default, root logger should be WARNING
But, triton driver initialization will overwrite root logger to INFO, which causes flakiness: https://github.com/pytorch/pytorch/issues/126143
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126144
Approved by: https://github.com/jansel
diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py
index 074d075..201dd4a 100644
--- a/test/inductor/test_compiled_autograd.py
+++ b/test/inductor/test_compiled_autograd.py
@@ -1,5 +1,6 @@
# Owner(s): ["module: inductor"]
import functools
+import logging
import re
import sys
import unittest
@@ -51,6 +52,14 @@
class TestCompiledAutograd(TestCase):
+ def setUp(self) -> None:
+ super().setUp()
+ compiled_autograd.reset()
+
+ def tearDown(self) -> None:
+ super().tearDown()
+ compiled_autograd.reset()
+
def check_output_and_recompiles(
self, fn, count=1, compiler_fn=compiler_fn, compile_fn=False
):
@@ -322,6 +331,7 @@
handle.remove()
def test_inputs_aliasing_bytecode_stack_restore(self):
+ logging.getLogger().setLevel(logging.WARNING)
from torch.testing._internal.logging_tensor import LoggingTensor
# Create a graph that allows inputs stealing
@@ -753,6 +763,52 @@
self.check_output_and_recompiles(fn, count=2)
@unittest.skipIf(not HAS_CUDA, "requires cuda")
+ def test_logging_tensor_flaky(self) -> None:
+ # when you first run some test using triton and then run test_inputs_aliasing_bytecode_stack_restore
+ # resulting in:
+ # - pytest: `TypeError: unsupported operand type(s) for +: 'Tensor' and 'LoggingTensor'`
+ # - python: `TypeError: not all arguments converted during string formatting`
+
+ # 1. some triton involving test
+ def fn():
+ def _fn(x):
+ return x
+
+ x = torch.arange(
+ 1, 10, requires_grad=True, dtype=torch.float16, device="cuda"
+ )
+ out = _fn(x)
+ loss = out.sum()
+ loss.backward()
+
+ with compiled_autograd.enable(compiler_fn):
+ fn()
+
+ logging.getLogger().setLevel(
+ logging.WARNING
+ ) # triton setup overwrote it to INFO
+ # 2. test_inputs_aliasing_bytecode_stack_restore
+ from torch.testing._internal.logging_tensor import LoggingTensor
+
+ def forward(inputs):
+ add = inputs[0] + 1
+ add_1 = add + inputs[1]
+ out = add_1.cpu()
+ return (out,)
+
+ gm = torch.fx.symbolic_trace(forward)
+ print(gm.print_readable())
+ torch._dynamo.utils.set_locals_to_steal(gm, ["inputs"])
+ compiled_fn = torch.compile(gm)
+
+ inputs = [
+ torch.ones(1000000, dtype=torch.float32),
+ LoggingTensor(torch.ones(1)),
+ ]
+
+ compiled_fn(inputs)
+
+ @unittest.skipIf(not HAS_CUDA, "requires cuda")
def test_custom_fn_output_metadata(self):
def my_compiler_fn(gm):
for node in gm.graph.nodes:
diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py
index e8e6104..386d0b4 100644
--- a/torch/_dynamo/compiled_autograd.py
+++ b/torch/_dynamo/compiled_autograd.py
@@ -319,3 +319,10 @@
if prior:
compiled_autograd_enabled = True
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+
+
+# return to starting state of a new process
+def reset() -> None:
+ compiled_autograd_enable = False
+ assert compiled_autograd_enabled_count == 0
+ torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
diff --git a/torch/testing/_internal/logging_tensor.py b/torch/testing/_internal/logging_tensor.py
index 5ddd537..8b7faf4 100644
--- a/torch/testing/_internal/logging_tensor.py
+++ b/torch/testing/_internal/logging_tensor.py
@@ -11,6 +11,7 @@
import functools
from torch._C._profiler import gather_traceback, symbolize_tracebacks
+logger = logging.getLogger("LoggingTensor")
_dtype_abbrs = {
torch.bfloat16: "bf16",
@@ -135,8 +136,8 @@
if self.tracebacks_list is not None:
self.tracebacks_list.append(record.traceback)
-def log_input(name: str, var: object):
- logging.getLogger("LoggingTensor").info("input", (name,), {}, var) # noqa: PLE1205
+def log_input(name: str, var: object) -> None:
+ logger.info("input", (name,), {}, var) # noqa: PLE1205
class GatherTraceback(logging.Filter):
def __init__(self, python=True, script=True, cpp=False):
@@ -151,7 +152,6 @@
@contextlib.contextmanager
def capture_logs(is_mode=False, python_tb=False, script_tb=False, cpp_tb=False) -> Iterator[List[str]]:
collect_traceback = python_tb or script_tb or cpp_tb
- logger = logging.getLogger("LoggingTensor")
log_list: List[str] = []
tracebacks_list: List[str] = []
handler = LoggingTensorHandler(