[pipelining] Add pipe.build_stage() (#128240)

The `PipelineStage` name has been given to the manual side, so this PR adds a `build_stage()` method under `Pipe` for creating a pipeline stage from the tracer frontend.
Also moved `PipeInfo` to `_utils.py` to avoid a circular dependency between `_IR` and `PipelineStage`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128240
Approved by: https://github.com/wconstab, https://github.com/H-Huang
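
For illustration, a minimal usage sketch of the new method against the public `torch.distributed.pipelining` frontend (the toy model, tensor sizes, and split point are assumptions for this example and not part of the PR; a process group with two ranks is assumed to be initialized, and keyword names follow the tests in this PR):

```python
import torch
import torch.distributed as dist
from torch.distributed.pipelining import pipeline, ScheduleGPipe, SplitPoint


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = torch.nn.Linear(16, 16)
        self.layer1 = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.layer1(self.layer0(x))


model = ToyModel()
x_mb = torch.randn(4, 16)  # a single example microbatch

# Trace and split the model with the tracer frontend.
pipe = pipeline(
    model,
    mb_args=(x_mb,),
    split_spec={"layer1": SplitPoint.BEGINNING},
)

rank = dist.get_rank()
device = torch.device("cpu")  # or this rank's GPU

# New in this PR: build the stage for this rank directly from the Pipe object
# (previously done by constructing a TracerPipelineStage).
stage = pipe.build_stage(rank, device)

# Attach the stage to a schedule as before.
schedule = ScheduleGPipe(stage, n_microbatches=2)
```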
diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst
index 05cefb2..32efef6 100644
--- a/docs/source/distributed.pipelining.rst
+++ b/docs/source/distributed.pipelining.rst
@@ -179,8 +179,7 @@
 
 .. code-block:: python
 
-  from torch.distributed.pipelining import TracerPipelineStage
-  stage = TracerPipelineStage(pipe, stage_idx, device)
+  stage = pipe.build_stage(stage_idx, device, group)
 
 .. note::
   The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your
@@ -354,8 +353,6 @@
 
 .. autoclass:: PipelineStage
 
-.. autoclass:: TracerPipelineStage
-
 Pipeline Schedules
 ==================
 
diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py
index ef9a636..22d1167 100644
--- a/test/distributed/pipelining/test_schedule.py
+++ b/test/distributed/pipelining/test_schedule.py
@@ -19,7 +19,6 @@
     ScheduleGPipe,
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
-    TracerPipelineStage,
 )
 from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType
 from torch.distributed.pipelining.PipelineStage import _PipelineStageBase
@@ -98,11 +97,9 @@
             split_spec=split_spec,
         )
 
-        stage = TracerPipelineStage(
-            pipe,
+        stage = pipe.build_stage(
             self.rank,
             self.device,
-            chunks,  # to be cleaned
         )
 
         # Attach to a schedule
@@ -140,11 +137,9 @@
             mb_kwargs={"y": y_mb},
         )
 
-        stage = TracerPipelineStage(
-            pipe,
+        stage = pipe.build_stage(
             self.rank,
             self.device,
-            chunks,  # to be cleaned
         )
 
         # Attach to a schedule
@@ -203,11 +198,9 @@
             split_spec=split_spec,
         )
 
-        stage = TracerPipelineStage(
-            pipe,
+        stage = pipe.build_stage(
             self.rank,
             self.device,
-            chunks,  # to be cleaned
         )
 
         # Attach to a schedule
diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py
index b11a603..959dd2e 100644
--- a/test/distributed/pipelining/test_stage.py
+++ b/test/distributed/pipelining/test_stage.py
@@ -8,12 +8,7 @@
 
 import torch
 import torch.distributed as dist
-from torch.distributed.pipelining import (
-    pipeline,
-    PipelineStage,
-    ScheduleGPipe,
-    TracerPipelineStage,
-)
+from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe
 from torch.distributed.pipelining._utils import PipeliningShapeError
 from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
@@ -91,11 +86,9 @@
             split_spec=split_spec,
         )
 
-        stage = TracerPipelineStage(
-            pipe,
+        stage = pipe.build_stage(
             self.rank,
             self.device,
-            chunks,  # to be cleaned
         )
 
         # Attach to a schedule
@@ -160,11 +153,9 @@
             mb_kwargs={"y": y_mb},
         )
 
-        stage = TracerPipelineStage(
-            pipe,
+        stage = pipe.build_stage(
             self.rank,
             self.device,
-            chunks,
         )
 
         # Attach to a schedule
diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py
index 2de3c4e..31632a8 100644
--- a/torch/distributed/pipelining/PipelineSchedule.py
+++ b/torch/distributed/pipelining/PipelineSchedule.py
@@ -4,17 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from enum import Enum
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    List,
-    NamedTuple,
-    Optional,
-    Tuple,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -23,9 +13,6 @@
 from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
 from .PipelineStage import _PipelineStageBase
 
-if TYPE_CHECKING:
-    from ._IR import Pipe
-
 
 __all__ = [
     "PipelineScheduleSingle",
@@ -84,8 +71,6 @@
 
         # Derived
         self._has_backward = self._loss_fn is not None
-        # To be filled by subclasses
-        self._pipe_info: Optional[Pipe.PipeInfo] = None
 
         # Holds the losses for each microbatch.
         self._internal_losses: List[torch.Tensor] = []
@@ -300,9 +285,6 @@
             kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
         )
-        self._pipe_info = (
-            stage.pipe_info if hasattr(stage, "pipe_info") else None  # type: ignore[attr-defined]
-        )
         # Self attributes
         self._stage = stage
         self._num_stages = stage.num_stages
@@ -595,9 +577,6 @@
             kwargs_chunk_spec=kwargs_chunk_spec,
             output_merge_spec=output_merge_spec,
         )
-        self._pipe_info = (
-            stages[0].pipe_info if hasattr(stages[0], "pipe_info") else None  # type: ignore[attr-defined]
-        )
         # Self attributes
         self._stages = stages
         self._num_stages = stages[0].num_stages
diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py
index f59e3e9..c18f91e 100644
--- a/torch/distributed/pipelining/PipelineStage.py
+++ b/torch/distributed/pipelining/PipelineStage.py
@@ -15,13 +15,16 @@
 
 from ._backward import stage_backward
 from ._debug import map_debug_info
-from ._IR import Pipe
-from ._utils import flatten_args, modify_graph_op_device, validate_tensors_metadata
+from ._utils import (
+    flatten_args,
+    modify_graph_op_device,
+    PipeInfo,
+    validate_tensors_metadata,
+)
 
 
 __all__ = [
     "PipelineStage",
-    "TracerPipelineStage",
 ]
 
 logger = logging.getLogger(__name__)
@@ -80,7 +83,8 @@
 class _PipelineStageBase(ABC):
     """
     Base class for pipeline stages.
-    Implements common methods used by both the `TracerPipelineStage` used by the tracing frontend and `PipelineStage`.
+    Defines or implements common methods used by both the `_PipelineStage` used
+    by the tracing frontend and the `PipelineStage` used by the manual frontend.
     """
 
     def __init__(
@@ -97,7 +101,6 @@
             stage_index (int): The index of this stage.
             num_stages (int): The total number of stages in this pipeline.
             device (torch.device): The device to run this stage on.
-            num_microbatches (int): The number of microbatches to be run with this stage.
             group (Optional[dist.ProcessGroup]): The process group to use for communication.
                 If `None`, the default process group will be used.
                 Default: `None`.
@@ -641,9 +644,8 @@
         self,
         stage_module: torch.nn.Module,
         stage_index: int,
-        pipe_info: Pipe.PipeInfo,
+        pipe_info: PipeInfo,
         device: torch.device,
-        num_chunks: int,
         group: Optional[dist.ProcessGroup] = None,
     ):
         """
@@ -904,28 +906,6 @@
         return grad_recv_info_tuple
 
 
-# TODO: Update this to be returned by helper method under Pipe (kwen)
-class TracerPipelineStage(_PipelineStage):
-    def __init__(
-        self,
-        pipe: Pipe,
-        stage_index: int,
-        device: torch.device,
-        num_chunks: int,  # To be cleaned
-        group: Optional[dist.ProcessGroup] = None,
-    ):
-        """
-        Create a pipeline stage given a `Pipe` (representing the whole pipeline) and a stage index.
-        """
-        # Find my stage module
-        stage_module = pipe.get_stage_module(stage_index)
-        # Get my pipe info
-        pipe_info = pipe.info()
-        super().__init__(
-            stage_module, stage_index, pipe_info, device, num_chunks, group
-        )
-
-
 # Manual PipelineStage functions and definition
 
 METADATA_TENSOR_LEN = 100
diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py
index 9c3e21b..ba6b042 100644
--- a/torch/distributed/pipelining/_IR.py
+++ b/torch/distributed/pipelining/_IR.py
@@ -3,7 +3,6 @@
 import logging
 import operator
 from collections import defaultdict
-from dataclasses import dataclass
 from enum import Enum
 from inspect import Parameter, signature, Signature
 from types import MethodType
@@ -11,6 +10,7 @@
 
 import torch
 import torch.fx as fx
+from torch.distributed import ProcessGroup
 from torch.export import ExportedProgram
 from torch.export.unflatten import (
     _assign_attr,
@@ -20,9 +20,11 @@
 )
 from torch.fx.node import map_aggregate
 from torch.fx.passes.split_module import split_module
-
 from ._backward import _null_coalesce_accumulate, stage_backward
 from ._unflatten import _outline_submodules
+from ._utils import PipeInfo
+
+from .PipelineStage import _PipelineStage
 
 
 logger = logging.getLogger(__name__)
@@ -485,12 +487,6 @@
 
 
 class Pipe(torch.nn.Module):
-    @dataclass
-    class PipeInfo:
-        graph: fx.Graph
-        num_stages: int
-        has_loss_and_backward: bool
-
     def __init__(
         self,
         split_gm: fx.GraphModule,
@@ -505,7 +501,6 @@
         self.num_stages: int = num_stages
         self.has_loss_and_backward = has_loss_and_backward
         self.loss_spec = loss_spec
-        self.pipe_info: Optional[Pipe.PipeInfo] = None
 
         for node in split_gm.graph.nodes:
             assert (
@@ -1044,12 +1039,6 @@
             )
             submod0.recompile()
 
-        # Create pipe info
-        pipe.pipe_info = Pipe.PipeInfo(
-            graph=pipe.split_gm.graph,
-            num_stages=pipe.num_stages,
-            has_loss_and_backward=pipe.has_loss_and_backward,
-        )
         return pipe
 
     def __str__(self):
@@ -1058,12 +1047,31 @@
     def __repr__(self):
         return self.split_gm.__repr__()
 
-    def info(self) -> PipeInfo:
-        if self.pipe_info is None:
-            raise RuntimeError(
-                "Pipe info is not available. Please use the `pipeline` method to create the `Pipe` object."
-            )
-        return self.pipe_info
+    def _info(self) -> PipeInfo:
+        return PipeInfo(
+            graph=self.split_gm.graph,
+            num_stages=self.num_stages,
+            has_loss_and_backward=self.has_loss_and_backward,
+        )
+
+    def build_stage(
+        self,
+        stage_index: int,
+        device: torch.device,
+        group: Optional[ProcessGroup] = None,
+    ) -> _PipelineStage:
+        """
+        Create a pipeline stage given a stage index and distributed context.
+        """
+        # Find stage module
+        stage_module = self.get_stage_module(stage_index)
+        # Detach pipe info
+        # Note: be careful about what is included in `pipe_info`. We don't want
+        # to keep a reference to `Pipe` or `Pipe.split_gm`, which would prevent
+        # Python from garbage-collecting them. Once they are collected, stage
+        # modules irrelevant to the current rank can be freed automatically.
+        pipe_info = self._info()
+        return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
 
 
 class SplitPoint(Enum):
diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py
index d9fd8fe..69e455e 100644
--- a/torch/distributed/pipelining/__init__.py
+++ b/torch/distributed/pipelining/__init__.py
@@ -6,14 +6,13 @@
     ScheduleInterleaved1F1B,
     ScheduleLoopedBFS,
 )
-from .PipelineStage import PipelineStage, TracerPipelineStage
+from .PipelineStage import PipelineStage
 
 __all__ = [
     "Pipe",
     "pipe_split",
     "SplitPoint",
     "pipeline",
-    "TracerPipelineStage",
     "PipelineStage",
     "Schedule1F1B",
     "ScheduleGPipe",
diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py
index f468053..72e96a3 100644
--- a/torch/distributed/pipelining/_utils.py
+++ b/torch/distributed/pipelining/_utils.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 import logging
+from dataclasses import dataclass
 from typing import List, Tuple, Union
 
 import torch
@@ -120,3 +121,14 @@
         validate_tensor_metadata(
             f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
         )
+
+
+@dataclass
+class PipeInfo:
+    """
+    Captures information for a pipeline (`Pipe` object).
+    """
+
+    graph: fx.Graph
+    num_stages: int
+    has_loss_and_backward: bool