[inductor] parallel compile: Create new pipes for subproc communication (#131194)
Summary: Rather then using stdin/stdout for IPC, we can create new pipes and pass the descriptors to the subproc via the cmd line. https://github.com/pytorch/pytorch/issues/131070 reports an issue where the combination of deepspeed and onnxruntime-training causes _something_ in the subproc to write to stdout and corrupt the IPC. The current implementation was already brittle; we can just create new pipes specifically for the IPC.
Test Plan: I was able to repro the MemoryError in https://github.com/pytorch/pytorch/issues/131070 by installing deepspeed and onnxruntime-training. Verified this PR fixes.
Differential Revision: [D59968362](https://our.internmc.facebook.com/intern/diff/D59968362)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131194
Approved by: https://github.com/malfet, https://github.com/eellison, https://github.com/atalman
diff --git a/torch/_inductor/compile_worker/__main__.py b/torch/_inductor/compile_worker/__main__.py
index a343dc6..547ad5f 100644
--- a/torch/_inductor/compile_worker/__main__.py
+++ b/torch/_inductor/compile_worker/__main__.py
@@ -3,10 +3,9 @@
import logging
import os
import sys
-import typing
from torch._inductor.async_compile import pre_fork_setup
-from torch._inductor.compile_worker.subproc_pool import Pipe, SubprocMain
+from torch._inductor.compile_worker.subproc_pool import SubprocMain
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
from torch._inductor.runtime.compile_tasks import _set_triton_ptxas_path
@@ -27,17 +26,13 @@
parser = argparse.ArgumentParser()
parser.add_argument("--workers", type=int)
parser.add_argument("--parent", type=int)
+ parser.add_argument("--read-fd", type=int)
+ parser.add_argument("--write-fd", type=int)
args = parser.parse_args()
if os.getppid() != args.parent:
sys.exit(0)
- write_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdout.fileno()), "wb"))
- read_fd = typing.cast(Pipe, os.fdopen(os.dup(sys.stdin.fileno()), "rb"))
-
- # nobody else should read stdin
- sys.stdin.close()
-
- # redirect output of workers to stderr
- os.dup2(sys.stderr.fileno(), sys.stdout.fileno())
+ read_fd = os.fdopen(args.read_fd, "rb")
+ write_fd = os.fdopen(args.write_fd, "wb")
pre_fork_setup()
diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py
index ce3da87..72f9426 100644
--- a/torch/_inductor/compile_worker/subproc_pool.py
+++ b/torch/_inductor/compile_worker/subproc_pool.py
@@ -21,20 +21,6 @@
log = logging.getLogger(__name__)
-class Pipe(typing.Protocol):
- def write(self, data: bytes):
- ...
-
- def read(self, n: int) -> bytes:
- ...
-
- def close(self):
- ...
-
- def flush(self):
- ...
-
-
def _pack_msg(job_id, length):
return struct.pack("nn", job_id, length)
@@ -103,16 +89,22 @@
def __init__(self, nprocs: int):
entry = os.path.join(os.path.dirname(__file__), "__main__.py")
+
+ subproc_read_fd, write_fd = os.pipe()
+ read_fd, subproc_write_fd = os.pipe()
+ self.write_pipe = os.fdopen(write_fd, "wb")
+ self.read_pipe = os.fdopen(read_fd, "rb")
+
cmd = [
sys.executable,
entry,
f"--workers={nprocs}",
f"--parent={os.getpid()}",
+ f"--read-fd={str(subproc_read_fd)}",
+ f"--write-fd={str(subproc_write_fd)}",
]
self.process = subprocess.Popen(
cmd,
- stdin=subprocess.PIPE,
- stdout=subprocess.PIPE,
env={
**os.environ,
# We need to set the PYTHONPATH so the subprocess can find torch.
@@ -124,10 +116,9 @@
# Some internal usages need a modified LD_LIBRARY_PATH.
"LD_LIBRARY_PATH": _get_ld_library_path(),
},
+ pass_fds=(subproc_read_fd, subproc_write_fd),
)
- self.write_pipe: Pipe = typing.cast(Pipe, self.process.stdin)
self.write_lock = threading.Lock()
- self.read_pipe: Pipe = typing.cast(Pipe, self.process.stdout)
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)
self.futures_lock = threading.Lock()
@@ -204,7 +195,7 @@
class SubprocMain:
"""Communicates with a SubprocPool in the parent process, called by __main__.py"""
- def __init__(self, nprocs: int, read_pipe: Pipe, write_pipe: Pipe):
+ def __init__(self, nprocs, read_pipe, write_pipe):
self.read_pipe = read_pipe
self.write_pipe = write_pipe
self.write_lock = threading.Lock()