[pipelining] Add pipe.build_stage() (#128240)
The `PipelineStage` name is now given to the manual frontend, so this PR adds a `build_stage()` method on `Pipe` that creates a pipeline stage from the tracer frontend (replacing `TracerPipelineStage`).
`PipeInfo` is moved to `_utils.py` to avoid a circular dependency between `_IR` and `PipelineStage`.
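
A minimal usage sketch of the new API, assuming a `Pipe` object `pipe` has already been created via `pipeline()` and that `stage_idx`, `device`, and `group` come from the caller's distributed setup (these names are illustrative, not part of this PR):

```python
from torch.distributed.pipelining import ScheduleGPipe

# Previously, the tracer frontend constructed its stage directly:
#   stage = TracerPipelineStage(pipe, stage_idx, device, num_chunks)
# Now the Pipe object builds the stage for the given rank; `group` is
# optional and falls back to the default process group.
stage = pipe.build_stage(stage_idx, device, group)

# The stage attaches to a schedule as before.
schedule = ScheduleGPipe(stage, n_microbatches=4)
```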
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128240
Approved by: https://github.com/wconstab, https://github.com/H-Huang
diff --git a/docs/source/distributed.pipelining.rst b/docs/source/distributed.pipelining.rst
index 05cefb2..32efef6 100644
--- a/docs/source/distributed.pipelining.rst
+++ b/docs/source/distributed.pipelining.rst
@@ -179,8 +179,7 @@
.. code-block:: python
- from torch.distributed.pipelining import TracerPipelineStage
- stage = TracerPipelineStage(pipe, stage_idx, device)
+ stage = pipe.build_stage(stage_idx, device, group)
.. note::
The ``pipeline`` frontend uses a tracer (``torch.export``) to capture your
@@ -354,8 +353,6 @@
.. autoclass:: PipelineStage
-.. autoclass:: TracerPipelineStage
-
Pipeline Schedules
==================
diff --git a/test/distributed/pipelining/test_schedule.py b/test/distributed/pipelining/test_schedule.py
index ef9a636..22d1167 100644
--- a/test/distributed/pipelining/test_schedule.py
+++ b/test/distributed/pipelining/test_schedule.py
@@ -19,7 +19,6 @@
ScheduleGPipe,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
- TracerPipelineStage,
)
from torch.distributed.pipelining.PipelineSchedule import _Action, _ComputationType
from torch.distributed.pipelining.PipelineStage import _PipelineStageBase
@@ -98,11 +97,9 @@
split_spec=split_spec,
)
- stage = TracerPipelineStage(
- pipe,
+ stage = pipe.build_stage(
self.rank,
self.device,
- chunks, # to be cleaned
)
# Attach to a schedule
@@ -140,11 +137,9 @@
mb_kwargs={"y": y_mb},
)
- stage = TracerPipelineStage(
- pipe,
+ stage = pipe.build_stage(
self.rank,
self.device,
- chunks, # to be cleaned
)
# Attach to a schedule
@@ -203,11 +198,9 @@
split_spec=split_spec,
)
- stage = TracerPipelineStage(
- pipe,
+ stage = pipe.build_stage(
self.rank,
self.device,
- chunks, # to be cleaned
)
# Attach to a schedule
diff --git a/test/distributed/pipelining/test_stage.py b/test/distributed/pipelining/test_stage.py
index b11a603..959dd2e 100644
--- a/test/distributed/pipelining/test_stage.py
+++ b/test/distributed/pipelining/test_stage.py
@@ -8,12 +8,7 @@
import torch
import torch.distributed as dist
-from torch.distributed.pipelining import (
- pipeline,
- PipelineStage,
- ScheduleGPipe,
- TracerPipelineStage,
-)
+from torch.distributed.pipelining import pipeline, PipelineStage, ScheduleGPipe
from torch.distributed.pipelining._utils import PipeliningShapeError
from torch.testing._internal.common_cuda import TEST_MULTIGPU
from torch.testing._internal.common_distributed import (
@@ -91,11 +86,9 @@
split_spec=split_spec,
)
- stage = TracerPipelineStage(
- pipe,
+ stage = pipe.build_stage(
self.rank,
self.device,
- chunks, # to be cleaned
)
# Attach to a schedule
@@ -160,11 +153,9 @@
mb_kwargs={"y": y_mb},
)
- stage = TracerPipelineStage(
- pipe,
+ stage = pipe.build_stage(
self.rank,
self.device,
- chunks,
)
# Attach to a schedule
diff --git a/torch/distributed/pipelining/PipelineSchedule.py b/torch/distributed/pipelining/PipelineSchedule.py
index 2de3c4e..31632a8 100644
--- a/torch/distributed/pipelining/PipelineSchedule.py
+++ b/torch/distributed/pipelining/PipelineSchedule.py
@@ -4,17 +4,7 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from enum import Enum
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- NamedTuple,
- Optional,
- Tuple,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
import torch
import torch.distributed as dist
@@ -23,9 +13,6 @@
from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
from .PipelineStage import _PipelineStageBase
-if TYPE_CHECKING:
- from ._IR import Pipe
-
__all__ = [
"PipelineScheduleSingle",
@@ -84,8 +71,6 @@
# Derived
self._has_backward = self._loss_fn is not None
- # To be filled by subclasses
- self._pipe_info: Optional[Pipe.PipeInfo] = None
# Holds the losses for each microbatch.
self._internal_losses: List[torch.Tensor] = []
@@ -300,9 +285,6 @@
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
)
- self._pipe_info = (
- stage.pipe_info if hasattr(stage, "pipe_info") else None # type: ignore[attr-defined]
- )
# Self attributes
self._stage = stage
self._num_stages = stage.num_stages
@@ -595,9 +577,6 @@
kwargs_chunk_spec=kwargs_chunk_spec,
output_merge_spec=output_merge_spec,
)
- self._pipe_info = (
- stages[0].pipe_info if hasattr(stages[0], "pipe_info") else None # type: ignore[attr-defined]
- )
# Self attributes
self._stages = stages
self._num_stages = stages[0].num_stages
diff --git a/torch/distributed/pipelining/PipelineStage.py b/torch/distributed/pipelining/PipelineStage.py
index f59e3e9..c18f91e 100644
--- a/torch/distributed/pipelining/PipelineStage.py
+++ b/torch/distributed/pipelining/PipelineStage.py
@@ -15,13 +15,16 @@
from ._backward import stage_backward
from ._debug import map_debug_info
-from ._IR import Pipe
-from ._utils import flatten_args, modify_graph_op_device, validate_tensors_metadata
+from ._utils import (
+ flatten_args,
+ modify_graph_op_device,
+ PipeInfo,
+ validate_tensors_metadata,
+)
__all__ = [
"PipelineStage",
- "TracerPipelineStage",
]
logger = logging.getLogger(__name__)
@@ -80,7 +83,8 @@
class _PipelineStageBase(ABC):
"""
Base class for pipeline stages.
- Implements common methods used by both the `TracerPipelineStage` used by the tracing frontend and `PipelineStage`.
+ Defines or implements common methods used by the `_PipelineStage` used by
+ the tracing frontend and `PipelineStage` used by manual frontend.
"""
def __init__(
@@ -97,7 +101,6 @@
stage_index (int): The index of this stage.
num_stages (int): The total number of stages in this pipeline.
device (torch.device): The device to run this stage on.
- num_microbatches (int): The number of microbatches to be run with this stage.
group (Optional[dist.ProcessGroup]): The process group to use for communication.
If `None`, the default process group will be used.
Default: `None`.
@@ -641,9 +644,8 @@
self,
stage_module: torch.nn.Module,
stage_index: int,
- pipe_info: Pipe.PipeInfo,
+ pipe_info: PipeInfo,
device: torch.device,
- num_chunks: int,
group: Optional[dist.ProcessGroup] = None,
):
"""
@@ -904,28 +906,6 @@
return grad_recv_info_tuple
-# TODO: Update this to be returned by helper method under Pipe (kwen)
-class TracerPipelineStage(_PipelineStage):
- def __init__(
- self,
- pipe: Pipe,
- stage_index: int,
- device: torch.device,
- num_chunks: int, # To be cleaned
- group: Optional[dist.ProcessGroup] = None,
- ):
- """
- Create a pipeline stage given a `Pipe` (representing the whole pipeline) and a stage index.
- """
- # Find my stage module
- stage_module = pipe.get_stage_module(stage_index)
- # Get my pipe info
- pipe_info = pipe.info()
- super().__init__(
- stage_module, stage_index, pipe_info, device, num_chunks, group
- )
-
-
# Manual PipelineStage functions and definition
METADATA_TENSOR_LEN = 100
diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py
index 9c3e21b..ba6b042 100644
--- a/torch/distributed/pipelining/_IR.py
+++ b/torch/distributed/pipelining/_IR.py
@@ -3,7 +3,6 @@
import logging
import operator
from collections import defaultdict
-from dataclasses import dataclass
from enum import Enum
from inspect import Parameter, signature, Signature
from types import MethodType
@@ -11,6 +10,7 @@
import torch
import torch.fx as fx
+from torch.distributed import ProcessGroup
from torch.export import ExportedProgram
from torch.export.unflatten import (
_assign_attr,
@@ -20,9 +20,11 @@
)
from torch.fx.node import map_aggregate
from torch.fx.passes.split_module import split_module
-
from ._backward import _null_coalesce_accumulate, stage_backward
from ._unflatten import _outline_submodules
+from ._utils import PipeInfo
+
+from .PipelineStage import _PipelineStage
logger = logging.getLogger(__name__)
@@ -485,12 +487,6 @@
class Pipe(torch.nn.Module):
- @dataclass
- class PipeInfo:
- graph: fx.Graph
- num_stages: int
- has_loss_and_backward: bool
-
def __init__(
self,
split_gm: fx.GraphModule,
@@ -505,7 +501,6 @@
self.num_stages: int = num_stages
self.has_loss_and_backward = has_loss_and_backward
self.loss_spec = loss_spec
- self.pipe_info: Optional[Pipe.PipeInfo] = None
for node in split_gm.graph.nodes:
assert (
@@ -1044,12 +1039,6 @@
)
submod0.recompile()
- # Create pipe info
- pipe.pipe_info = Pipe.PipeInfo(
- graph=pipe.split_gm.graph,
- num_stages=pipe.num_stages,
- has_loss_and_backward=pipe.has_loss_and_backward,
- )
return pipe
def __str__(self):
@@ -1058,12 +1047,31 @@
def __repr__(self):
return self.split_gm.__repr__()
- def info(self) -> PipeInfo:
- if self.pipe_info is None:
- raise RuntimeError(
- "Pipe info is not available. Please use the `pipeline` method to create the `Pipe` object."
- )
- return self.pipe_info
+ def _info(self) -> PipeInfo:
+ return PipeInfo(
+ graph=self.split_gm.graph,
+ num_stages=self.num_stages,
+ has_loss_and_backward=self.has_loss_and_backward,
+ )
+
+ def build_stage(
+ self,
+ stage_index: int,
+ device: torch.device,
+ group: Optional[ProcessGroup] = None,
+ ) -> _PipelineStage:
+ """
+ Create a pipeline stage given a stage index and distributed context.
+ """
+ # Find stage module
+ stage_module = self.get_stage_module(stage_index)
+ # Detach pipe info
+ # Note: be careful what's included in `pipe_info`. We don't want to keep
+ # a reference to `Pipe` or `Pipe.split_gm` which stops python from
+ # recycling them. When python recycles them, other stage modules (which
+ # are irrelevant to current rank) can be automatically freed.
+ pipe_info = self._info()
+ return _PipelineStage(stage_module, stage_index, pipe_info, device, group)
class SplitPoint(Enum):
diff --git a/torch/distributed/pipelining/__init__.py b/torch/distributed/pipelining/__init__.py
index d9fd8fe..69e455e 100644
--- a/torch/distributed/pipelining/__init__.py
+++ b/torch/distributed/pipelining/__init__.py
@@ -6,14 +6,13 @@
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
)
-from .PipelineStage import PipelineStage, TracerPipelineStage
+from .PipelineStage import PipelineStage
__all__ = [
"Pipe",
"pipe_split",
"SplitPoint",
"pipeline",
- "TracerPipelineStage",
"PipelineStage",
"Schedule1F1B",
"ScheduleGPipe",
diff --git a/torch/distributed/pipelining/_utils.py b/torch/distributed/pipelining/_utils.py
index f468053..72e96a3 100644
--- a/torch/distributed/pipelining/_utils.py
+++ b/torch/distributed/pipelining/_utils.py
@@ -1,5 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
+from dataclasses import dataclass
from typing import List, Tuple, Union
import torch
@@ -120,3 +121,14 @@
validate_tensor_metadata(
f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
)
+
+
+@dataclass
+class PipeInfo:
+ """
+ Captures information for a pipeline (`Pipe` object).
+ """
+
+ graph: fx.Graph
+ num_stages: int
+ has_loss_and_backward: bool