# Copyright (c) Meta Platforms, Inc. and affiliates
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch.profiler import record_function
from ._IR import Pipe
from ._PipelineStage import PipelineStageBase
from .microbatch import merge_chunks, split_args_kwargs_into_chunks
logger = logging.getLogger(__name__)
class PipelineSchedule(ABC):
def __init__(
self,
n_microbatches: int,
loss_fn: Optional[Callable[..., torch.Tensor]] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
# From arguments
self._n_microbatches = n_microbatches
self._loss_fn = loss_fn
self._output_merge_spec = output_merge_spec
# Derived
self._has_backward = self._loss_fn is not None
# To be filled by subclasses
self._pipe_info: Optional[Pipe.PipeInfo] = None
# Holds the losses for each microbatch.
self._internal_losses: List[torch.Tensor] = []
logger.info(f"Using {self.__class__.__name__}") # noqa: G004
def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
if stage.is_last and self._has_backward:
loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
self._internal_losses.append(loss)
logger.debug(
f"[{stage.stage_index}] Loss of microbatch {mb_index}: {loss}" # noqa: G004
)
def _maybe_get_loss(self, stage, mb_index):
valid_index = 0 <= mb_index < len(self._internal_losses)
if stage.is_last and self._has_backward and valid_index:
return self._internal_losses[mb_index]
elif len(self._internal_losses) != 0 and not valid_index:
raise RuntimeError(
f"Loss for microbatch {mb_index} is not available. "
f"Available losses for microbatches: {self._internal_losses}"
)
else:
return None
def _update_losses(self, stages, losses):
"""
Update the losses to those in the internal state
"""
# if stages not a list turn into a list
if not isinstance(stages, list):
stages = [stages]
contains_last_stage = any(stage.is_last for stage in stages)
# Return losses if there is a container passed in
if contains_last_stage and losses is not None:
if len(self._internal_losses) != self._n_microbatches:
raise RuntimeError(
f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
)
# Clean external container first
losses.clear()
# Copy internal losses to external container
losses.extend(self._internal_losses)
self._internal_losses.clear()
@abstractmethod
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
        Run one iteration of the pipeline schedule with a list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.
        Args:
            arg_mbs: list of positional-argument tuples, one per microbatch.
            kwarg_mbs: list of keyword-argument dicts, one per microbatch.
            target_mbs: list of targets for the loss function, one per microbatch.
            losses: optional list to be populated with the per-microbatch losses.
"""
raise NotImplementedError
@abstractmethod
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
"""
Run one iteration of the pipeline schedule with *whole-batch* input.
Will chunk the input into microbatches automatically, and go through the
microbatches according to the schedule implementation.
        Args:
            args: positional arguments to the model (as in non-pipeline case).
            kwargs: keyword arguments to the model (as in non-pipeline case).
            target: target for the loss function.
            losses: a list to store the losses for each microbatch.
"""
raise NotImplementedError
def _check_inputs(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
Pre-process/check inputs
"""
def check_type_and_len(mbs, name: str):
if not isinstance(mbs, list):
raise TypeError(f"{name} must be a list but got a {type(mbs)}")
if len(mbs) != self._n_microbatches:
raise ValueError(
f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
)
if arg_mbs is not None:
check_type_and_len(arg_mbs, "arg_mbs")
else:
arg_mbs = [()] * self._n_microbatches
if kwarg_mbs is not None:
check_type_and_len(kwarg_mbs, "kwarg_mbs")
else:
kwarg_mbs = [{}] * self._n_microbatches
if target_mbs is not None:
check_type_and_len(target_mbs, "target_mbs")
if losses is not None:
if not isinstance(losses, list):
raise TypeError(f"losses must be a list but got a {type(losses)}")
return arg_mbs, kwarg_mbs
def _compute_loss(self, output, target):
return self._loss_fn(output, target) # type: ignore[misc]
def _split_inputs(
self,
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
):
"""
Splits a full-batch input into chunks (i.e. microbatches) and returns
the chunks
"""
if self._pipe_info is not None:
# Use spec from `pipe_info`
args_chunk_spec = self._pipe_info.args_chunk_spec
kwargs_chunk_spec = self._pipe_info.kwargs_chunk_spec
else:
# Use default spec from `microbatch.py` (i.e. chunk dim 0 for each arg/kwarg)
args_chunk_spec = None
kwargs_chunk_spec = None
if args or kwargs:
args_split, kwargs_split = split_args_kwargs_into_chunks(
args,
kwargs,
self._n_microbatches,
args_chunk_spec,
kwargs_chunk_spec,
)
return args_split, kwargs_split
else:
# Empty inputs (e.g. when called on middle stages)
            # Return lists of empty tuples/dicts whose length matches the number of chunks
return [()] * self._n_microbatches, [{}] * self._n_microbatches
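    # Illustrative sketch (not executed): with n_microbatches=2 and a whole-batch
    # input `x` of shape [8, 16], `_split_inputs((x,), {})` returns roughly
    # ([(x[0:4],), (x[4:8],)], [{}, {}]) under the default chunk spec, which
    # splits every tensor arg/kwarg along dim 0.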
def _merge_outputs(self, output_chunks: List[Any]) -> Any:
"""
Merge output chunks back to a batch state.
If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
"""
return merge_chunks(
output_chunks,
self._output_merge_spec,
)
def _batch_p2p(p2p_ops: List[dist.P2POp], desc: Optional[str] = None):
"""
    Simple wrapper over batch_isend_irecv from torch.distributed that adds descriptive debug logging on top.
"""
desc_str = f"{desc}, " if desc else ""
logger.debug(f"batch_p2p {desc_str}{p2p_ops}") # noqa: G004
return dist.batch_isend_irecv(p2p_ops).pop()
def _sorted_batch_p2p(
p2p_ops: List[dist.P2POp], desc: Optional[str] = None
) -> Dict[int, dist.Work]:
"""
    Sorts the list of P2P ops by peer rank, then calls batch_isend_irecv once
    per peer. Returns a dictionary of works keyed by peer rank. This function
    helps us avoid hangs in case of skip connections.
"""
# Arrange p2p_ops by peer rank:
# int is the peer rank;
# List is the list of ops towards the peer
ops_by_peer: Dict[int, List[dist.P2POp]] = defaultdict(list)
work_by_peer: Dict[int, dist.Work] = {}
if len(p2p_ops) == 0:
return work_by_peer
# Classify the ops by peer rank
for op in p2p_ops:
ops_by_peer[op.peer].append(op)
# Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
for peer, ops in sorted(ops_by_peer.items()):
work_by_peer[peer] = _batch_p2p(ops, desc=desc)
return work_by_peer
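# Illustrative sketch (not executed): if a stage has recv ops from both rank 1
# and rank 3 (e.g. due to a skip connection), _sorted_batch_p2p groups them as
# {1: [...], 3: [...]} and issues one batch_isend_irecv per peer in ascending
# peer order, so every rank posts its batched ops in the same relative order.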
class PipelineScheduleSingle(PipelineSchedule):
"""
Base class for single-stage schedules.
Implements the `step` method.
Derived classes should implement `_step_microbatches`.
"""
def __init__(
self,
stage: PipelineStageBase,
n_microbatches: int,
loss_fn: Optional[Callable] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
# Init parent
super().__init__(
n_microbatches=n_microbatches,
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
)
self._pipe_info = (
stage.pipe_info if hasattr(stage, "pipe_info") else None # type: ignore[attr-defined]
)
# Self attributes
self._stage = stage
self._num_stages = stage.num_stages
# Set the same has_backward flag for stage object
self._stage.has_backward = self._has_backward
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
# Clean per iteration
self._stage.clear_runtime_states()
# Split inputs into microbatches
args_split, kwargs_split = self._split_inputs(args, kwargs)
# Split target into microbatches
if target is not None:
targets_split = list(torch.tensor_split(target, self._n_microbatches))
else:
targets_split = None
# Run microbatches
self._step_microbatches(args_split, kwargs_split, targets_split, losses)
# Return merged results per original format
if self._stage.is_last:
return self._merge_outputs(self._stage.output_chunks)
else:
return None
class ScheduleGPipe(PipelineScheduleSingle):
"""
The GPipe schedule.
Will go through all the microbatches in a fill-drain manner.
"""
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
        Run one iteration of the pipeline schedule with a list of microbatches.
        Will go through all the microbatches according to the GPipe schedule.
        Args:
            arg_mbs, kwarg_mbs, target_mbs, losses: see `PipelineSchedule._step_microbatches`.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
# Run microbatches
for i in range(self._n_microbatches):
with record_function(f"Forward {i}"):
ops = self._stage.get_fwd_recv_ops()
works = _sorted_batch_p2p(ops, desc="fwd_recv")
for work in works.values():
work.wait()
output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
ops = self._stage.get_fwd_send_ops()
works = _sorted_batch_p2p(ops, desc="fwd_send")
fwd_sends_to_wait.extend(works.values())
logger.debug(
f"[{self._stage.stage_index}] Forwarded microbatch {i}" # noqa: G004
)
self._maybe_compute_loss(self._stage, output, target_mbs, i)
# Wait for all forward sends to finish
        # This should not have a performance impact because, by the time the
        # first backward arrives, all the forward sends should have finished.
for work in fwd_sends_to_wait:
work.wait()
# No loss function, no need to run backward
if not self._has_backward:
return
# Run backward
# Delay send waits
bwd_sends_to_wait: List[dist.Work] = []
for i in range(self._n_microbatches):
# set library-specific data-parallel config flags to ensure gradient accumulation across microbatches
self._stage._configure_data_parallel_mode(i == self._n_microbatches - 1)
with record_function(f"Backward {i}"):
ops = self._stage.get_bwd_recv_ops()
works = _sorted_batch_p2p(ops, desc="bwd_recv")
for work in works.values():
work.wait()
loss = self._maybe_get_loss(self._stage, i)
self._stage.backward_one_chunk(loss=loss)
ops = self._stage.get_bwd_send_ops()
works = _sorted_batch_p2p(ops, desc="bwd_send")
bwd_sends_to_wait.extend(works.values())
logger.debug(
f"[{self._stage.stage_index}] Backwarded microbatch {i}" # noqa: G004
)
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
work.wait()
class Schedule1F1B(PipelineScheduleSingle):
"""
The 1F1B schedule.
Will perform one forward and one backward on the microbatches in steady state.
"""
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
        Run one iteration of the pipeline schedule with a list of microbatches.
        Will go through all the microbatches according to the 1F1B schedule.
        Args:
            arg_mbs, kwarg_mbs, target_mbs, losses: see `PipelineSchedule._step_microbatches`.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
# forward for num_microbatches + backward for num_microbatches
total_ops = self._n_microbatches * 2
        # Example: 4 GPUs, 8 microbatches
# Stage 0: 6 warmup, 2 1f1b, 6 cooldown
# Stage 1: 4 warmup, 4 1f1b, 4 cooldown
# Stage 2: 2 warmup, 6 1f1b, 2 cooldown
# Stage 3: 0 warmup, 8 1f1b, 0 cooldown
# fwd only
warmup_steps = min(
self._n_microbatches,
2 * (self._num_stages - self._stage.stage_index - 1),
)
# fwd + bwd
main_1f1b_steps = self._n_microbatches - warmup_steps
# bwd only
cooldown_steps = total_ops - (warmup_steps + (2 * main_1f1b_steps))
total_steps = warmup_steps + main_1f1b_steps + cooldown_steps
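        # Worked example for stage 1 in the 4-GPU / 8-microbatch case above:
        #   warmup_steps    = min(8, 2 * (4 - 1 - 1)) = 4
        #   main_1f1b_steps = 8 - 4                   = 4
        #   cooldown_steps  = 16 - (4 + 2 * 4)        = 4
        #   total_steps     = 4 + 4 + 4               = 12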
logger.debug(
f"Stage {self._stage.stage_index}: " # noqa: G004
f"Warmup steps: {warmup_steps}, "
f"Main 1F1B steps: {main_1f1b_steps}, "
f"Cooldown steps: {cooldown_steps}, "
f"Total steps: {total_steps}"
)
# Delay send waits
fwd_sends_to_wait: List[dist.Work] = []
bwd_sends_to_wait: List[dist.Work] = []
def is_forward_step(i):
assert i >= 0, i
return i < self._n_microbatches
def is_backward_step(i):
assert i < total_steps, i
return i >= warmup_steps and self._has_backward
def is_1f1b_step(i):
return is_forward_step(i) and is_backward_step(i)
def is_warmup_step(i):
return is_forward_step(i) and not is_backward_step(i)
def is_cooldown_step(i):
return not is_forward_step(i) and is_backward_step(i)
def should_coalesce_fwd_send_bwd_recv(fwd_send_i):
return (
is_1f1b_step(fwd_send_i)
or (is_warmup_step(fwd_send_i) and is_cooldown_step(fwd_send_i + 1))
or (
fwd_send_i >= 1
and is_warmup_step(fwd_send_i - 1)
and is_cooldown_step(fwd_send_i)
)
)
def should_coalesce_bwd_send_fwd_recv(bwd_send_i):
# The backward send to prev stage should be coalesced with the fwd recv from the previous stage
return bwd_send_i >= warmup_steps and is_1f1b_step(bwd_send_i + 1)
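        # Rough intuition for the coalescing above (assumption based on the
        # surrounding code, not a spec): in steady state the fwd send and the
        # bwd recv both target the next-stage peer, and the bwd send and the
        # fwd recv both target the previous-stage peer, so batching each pair
        # into a single p2p call keeps the communication order aligned across
        # ranks.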
# bwd chunk counter
bwd_mb_index = 0
self._stage._configure_data_parallel_mode(last_backward=False)
for i in range(total_steps):
if is_forward_step(i):
with record_function(f"Forward {i}"):
ops = self._stage.get_fwd_recv_ops()
desc = "fwd_recv"
if should_coalesce_bwd_send_fwd_recv(i - 1):
desc += "_bwd_send"
ops.extend(self._stage.get_bwd_send_ops())
works = _sorted_batch_p2p(ops, desc=desc)
for work in works.values():
work.wait()
output = self._stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
if not should_coalesce_fwd_send_bwd_recv(i):
ops = self._stage.get_fwd_send_ops()
works = _sorted_batch_p2p(ops, desc="fwd_send")
fwd_sends_to_wait.extend(works.values())
self._maybe_compute_loss(self._stage, output, target_mbs, i)
if is_backward_step(i):
self._stage._configure_data_parallel_mode(
last_backward=(i == total_steps - 1)
)
with record_function(f"Backward {bwd_mb_index}"):
ops = self._stage.get_bwd_recv_ops()
desc = "bwd_recv"
if should_coalesce_fwd_send_bwd_recv(i):
ops.extend(self._stage.get_fwd_send_ops())
desc += "_fwd_send"
works = _sorted_batch_p2p(ops, desc=desc)
for work in works.values():
work.wait()
loss = self._maybe_get_loss(self._stage, bwd_mb_index)
self._stage.backward_one_chunk(loss=loss)
if not should_coalesce_bwd_send_fwd_recv(i):
# see Note: coalesced bwd-send/fwd-recv
ops = self._stage.get_bwd_send_ops()
works = _sorted_batch_p2p(ops, desc="bwd_send")
bwd_sends_to_wait.extend(works.values())
bwd_mb_index += 1
# Wait for all forward sends to finish
for work in fwd_sends_to_wait:
work.wait()
# Wait for all backward sends to finish
for work in bwd_sends_to_wait:
work.wait()
# Return losses if there is a container passed in
self._update_losses(self._stage, losses)
class PipelineScheduleMulti(PipelineSchedule):
"""
Base class for multi-stage schedules.
Implements the `step` method.
Derived classes should implement `_step_microbatches`.
"""
def __init__(
self,
stages: List[PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
if len(stages) <= 1:
raise ValueError(
f"Multi-stage schedule expects at least two stages but got {len(stages)}"
)
# Init parent
super().__init__(
n_microbatches=n_microbatches,
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
)
self._pipe_info = (
stages[0].pipe_info if hasattr(stages[0], "pipe_info") else None # type: ignore[attr-defined]
)
# Self attributes
self._stages = stages
self._num_stages = stages[0].num_stages
        # Set the same has_backward flag for each stage object
for stage in self._stages:
stage.has_backward = self._has_backward
self._should_compute_loss = (
lambda stage: stage.is_last and self._loss_fn is not None
)
def step(self, *args, target=None, losses: Optional[List] = None, **kwargs):
# Clean per iteration
for stage in self._stages:
stage.clear_runtime_states()
# Split inputs into microbatches
args_split, kwargs_split = self._split_inputs(args, kwargs)
# Split target into microbatches
if target is not None:
targets_split = list(torch.tensor_split(target, self._n_microbatches))
else:
targets_split = None
# Run microbatches
self._step_microbatches(args_split, kwargs_split, targets_split, losses)
# Return merged results per original format
for stage in self._stages:
if stage.is_last:
return self._merge_outputs(stage.output_chunks)
        # No stage on this rank is the last stage, so there is no output to merge
return None
class ScheduleLoopedBFS(PipelineScheduleMulti):
"""
Breadth-First Pipeline Parallelism.
See https://arxiv.org/abs/2211.05953 for details.
    Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    What is different is that, when microbatches are ready for multiple local
    stages, Looped BFS prioritizes the earlier stage, running all available
    microbatches at once.
"""
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None, # TODO
losses: Optional[List] = None, # TODO
):
"""
        Run one iteration of the pipeline schedule with a list of microbatches.
        Will go through all the microbatches according to the Looped BFS schedule.
        Args:
            arg_mbs, kwarg_mbs, target_mbs, losses: see `PipelineSchedule._step_microbatches`.
"""
# Pre-process inputs
if arg_mbs is not None:
# TODO: fix this so it is preset
self._n_microbatches = len(arg_mbs)
assert len(arg_mbs) == self._n_microbatches
else:
arg_mbs = [()] * self._n_microbatches
if kwarg_mbs is not None:
assert len(kwarg_mbs) == self._n_microbatches
else:
kwarg_mbs = [{}] * self._n_microbatches
for stage in self._stages:
for i in range(self._n_microbatches):
with record_function(f"Stage {stage.stage_index} Forward"):
ops = stage.get_fwd_recv_ops()
if ops:
_batch_p2p(ops, desc="fwd_recv").wait()
output = stage.forward_one_chunk(arg_mbs[i], kwarg_mbs[i])
self._maybe_compute_loss(stage, output, target_mbs, i)
ops = stage.get_fwd_send_ops()
if ops:
_batch_p2p(ops, desc="fwd_send")
for stage in reversed(self._stages):
for i in range(self._n_microbatches):
stage._configure_data_parallel_mode(i == self._n_microbatches - 1)
with record_function(f"Stage {stage.stage_index} Backward"):
ops = stage.get_bwd_recv_ops()
if ops:
_batch_p2p(ops, desc="bwd_recv").wait()
loss = self._maybe_get_loss(stage, i)
stage.backward_one_chunk(loss=loss)
ops = stage.get_bwd_send_ops()
if ops:
_batch_p2p(ops, desc="bwd_send")
self._update_losses(self._stages, losses)
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
"""
The Interleaved 1F1B schedule.
    Will perform one forward and one backward on the microbatches in steady
    state. Supports multiple stages per rank. When microbatches are ready for
multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
(also called "depth first").
"""
def __init__(
self,
stages: List[PipelineStageBase],
n_microbatches: int,
loss_fn: Optional[Callable] = None,
output_merge_spec: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
):
self.pp_group_size = stages[0].group_size
# TODO: is this limitation a must?
if n_microbatches % self.pp_group_size != 0:
raise ValueError(
"Interleaved 1F1B requires the number of microbatches to be a "
f"multiple of the number of pipeline ranks ({self.pp_group_size}), "
f"but got {n_microbatches}."
)
super().__init__(
stages=stages,
n_microbatches=n_microbatches,
loss_fn=loss_fn,
output_merge_spec=output_merge_spec,
)
self.n_local_stages = len(stages)
self.rank = stages[0].group_rank
def _step_microbatches(
self,
arg_mbs: Optional[List] = None,
kwarg_mbs: Optional[List] = None,
target_mbs: Optional[List] = None,
losses: Optional[List] = None,
):
"""
        Operate on the microbatches for the interleaved 1F1B schedule (https://arxiv.org/pdf/2104.04473.pdf).
        The highest rank has a warmup (forward-only) count of
        [len(stages) - 1] * (number of PP ranks), and each rank farther from the
        highest rank adds 2 warmup steps because:
        - one forward happens before the highest rank's warmup starts,
        - one step is spent waiting for the backward result to trickle down from the highest rank.
TODO: Interleaved 1F1B does not support using _sorted_batch_p2p()
because it requires recvs and sends from different peers
to execute in the same coalesced operation. As a result, this schedule does
not support models with skip connections.
"""
arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
# increment warmup_steps by 2 for each hop away
warmup_steps = (self.n_local_stages - 1) * self.pp_group_size
warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
warmup_steps = min(warmup_steps, self._n_microbatches * self.n_local_stages)
fwd_bwd_steps = (self.n_local_stages * self._n_microbatches) - warmup_steps
cooldown_steps = (self.n_local_stages * self._n_microbatches) - fwd_bwd_steps
assert (
warmup_steps + fwd_bwd_steps * 2 + cooldown_steps
== self.n_local_stages * self._n_microbatches * 2
)
total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
logger.debug(
f"rank {self.rank}, warmup_steps {warmup_steps}, " # noqa: G004
f"1f1b {fwd_bwd_steps}, cooldown_steps {cooldown_steps}"
)
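        # Worked example (rank 0 of a 4-rank group, 2 local stages, 8 microbatches):
        #   warmup_steps   = (2 - 1) * 4 + 2 * ((4 - 1) - 0) = 10
        #   fwd_bwd_steps  = 2 * 8 - 10                      = 6
        #   cooldown_steps = 2 * 8 - 6                       = 10
        #   total_steps    = 10 + 6 + 10                     = 26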
def forward_stage_local_index(step):
return (step // self.pp_group_size) % self.n_local_stages
def backward_stage_local_index(step):
return (
self.n_local_stages
- 1
- ((step - warmup_steps) // self.pp_group_size) % self.n_local_stages
)
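        # Continuing the worked example above (pp_group_size=4, n_local_stages=2,
        # warmup_steps=10): forward_stage_local_index maps steps 0-3 -> local
        # stage 0, steps 4-7 -> local stage 1, steps 8-11 -> local stage 0, ...;
        # backward_stage_local_index maps steps 10-13 -> local stage 1,
        # steps 14-17 -> local stage 0, and so on.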
fwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
bwd_stage_mb_index: Dict[PipelineStageBase, int] = defaultdict(int)
# Delay send waits
sends_to_wait: List[dist.Work] = []
# Store ops (potentially across steps)
ops: List[dist.P2POp] = []
# Warmup Phase (forward only)
for step in range(warmup_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]
# This will assign the current microbatch index and update it for future steps
fwd_stage_mb_index[fwd_stage] = (
mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}" # noqa: G004
)
with record_function(f"Forward {step}"):
ops.extend(fwd_stage.get_fwd_recv_ops())
if ops:
work = _batch_p2p(ops, desc="warmup_pre_fwd")
work.wait()
ops.clear()
output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index]) # type: ignore[index]
ops.extend(fwd_stage.get_fwd_send_ops())
                # If we are right before the fwd-bwd steady-state steps, we need to delay this send to the next step.
                # This is because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
                # In the edge case where there are no fwd-bwd steps and cooldown is immediate, no delay is needed.
if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
work = _batch_p2p(ops, desc="warmup_post_fwd")
sends_to_wait.append(work)
ops.clear()
self._maybe_compute_loss(fwd_stage, output, target_mbs, mb_index)
# 1F1B Phase (forward and backward)
for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
fwd_stage = self._stages[forward_stage_local_index(step)]
bwd_stage = self._stages[backward_stage_local_index(step)]
fwd_stage_mb_index[fwd_stage] = (
fwd_mb_index := fwd_stage_mb_index[fwd_stage]
) + 1
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, " # noqa: G004
f"{bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
)
desc = f"1F1B {step}"
with record_function(desc):
ops.extend(fwd_stage.get_fwd_recv_ops())
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = _batch_p2p(ops, desc=desc)
work.wait()
ops.clear()
# Forward
output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]) # type: ignore[index]
ops.extend(fwd_stage.get_fwd_send_ops())
self._maybe_compute_loss(fwd_stage, output, target_mbs, fwd_mb_index)
# Backward
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())
# Cooldown Phase (backward only)
for step in range(warmup_steps + fwd_bwd_steps, total_steps):
bwd_stage = self._stages[backward_stage_local_index(step)]
bwd_stage_mb_index[bwd_stage] = (
bwd_mb_index := bwd_stage_mb_index[bwd_stage]
) + 1
bwd_stage._configure_data_parallel_mode(
bwd_mb_index == self._n_microbatches - 1
)
logger.debug(
f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}" # noqa: G004
)
desc = f"Cooldown {step}"
with record_function(desc):
ops.extend(bwd_stage.get_bwd_recv_ops())
if ops:
work = _batch_p2p(ops, desc=desc + " pre_bwd")
work.wait()
ops.clear()
loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
bwd_stage.backward_one_chunk(loss=loss)
ops.extend(bwd_stage.get_bwd_send_ops())
if ops:
work = _batch_p2p(ops, desc=desc + " post_bwd")
sends_to_wait.append(work)
ops.clear()
# Make sure all sends are finished
for work in sends_to_wait:
work.wait()
# Return losses if there is a container passed in
self._update_losses(self._stages, losses)