explicitly reset stderr/stdout in precompilation (#125289)

I was seeing a weird bug where after running max-autotune my stdout would be misdirected. other people have not been able to repro this.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125289
Approved by: https://github.com/shunting314, https://github.com/mlazos
diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py
index 577a1c3..8fcb441 100644
--- a/torch/_inductor/select_algorithm.py
+++ b/torch/_inductor/select_algorithm.py
@@ -40,6 +40,7 @@
 from .utils import (
     get_dtype_size,
     Placeholder,
+    restore_stdout_stderr,
     sympy_dot,
     sympy_index_symbol,
     sympy_product,
@@ -1007,15 +1008,28 @@
                 num_workers,
             )
 
+            # In rare circumstances, because python threads inherit global state,
+            # thread pool executor can race and leave stdout/stderr in a state
+            # different than the original values. we explicitly restore the state
+            # here to avoid this issue.
+
+            initial_stdout = sys.stdout
+            initial_stderr = sys.stderr
+
+            def precompile_with_captured_stdout(choice):
+                with restore_stdout_stderr(initial_stdout, initial_stderr):
+                    return choice.precompile()
+
             executor = ThreadPoolExecutor(max_workers=num_workers)
             futures = executor.map(
-                lambda c: c.precompile(),
+                lambda c: precompile_with_captured_stdout(c),
                 [c for c in choices if hasattr(c, "precompile")],
                 timeout=precompilation_timeout_seconds,
             )
             from triton.runtime.autotuner import OutOfResources
 
             @functools.lru_cache(None)
+            @restore_stdout_stderr(initial_stdout, initial_stderr)
             def wait_on_futures():
                 counters["inductor"]["select_algorithm_precompile"] += 1
                 try:
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 27bf361..d4063b5 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -848,6 +848,15 @@
         return res
 
 
+@contextlib.contextmanager
+def restore_stdout_stderr(initial_stdout, initial_stderr):
+    try:
+        yield
+    finally:
+        sys.stdout = initial_stdout
+        sys.stderr = initial_stderr
+
+
 class DeferredLineBase:
     """A line that can be 'unwritten' at a later time"""