blob: a66f694abd21aed3007f35491ba047d2e1849318 [file] [log] [blame]
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
import logging
from typing import List
import torch
from torch.distributed.pipelining import (
ScheduleFlexibleInterleaved1F1B,
ScheduleInterleaved1F1B,
ScheduleLoopedBFS,
)
from torch.distributed.pipelining.schedules import (
_Action,
_add_send_recv,
_add_unshard_reshard,
_format_pipeline_order,
_PipelineSchedule,
_validate_pipeline_order,
B,
F,
get_schedule_class,
RECV_F,
RESHARD,
SEND_B,
UNSHARD,
W,
)
from torch.distributed.pipelining.stage import _PipelineStageBase
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
TestCase,
)
logger = logging.getLogger(__name__)
torch.manual_seed(0)
class MockPipelineStage(_PipelineStageBase):
def __init__(self, *args, **kwargs):
# Mock the necessary attributes
self.num_stages = kwargs.get("num_stages", 1)
self.group_size = kwargs.get("group_size", 1)
self.group_rank = kwargs.get("group_rank", 0)
self.group = kwargs.get("group", None)
self.stage_index_to_group_rank = kwargs.get("stage_index_to_group_rank", None)
def _create_grad_recv_info(self, *args, **kwargs):
return None
def _prepare_forward_infra(self, n_microbatches):
pass
def _prepare_backward_infra(self, n_microbatches):
pass
class ScheduleTest(TestCase):
def test_get_schedule_class(self):
# List of all expected schedule names
schedule_names = [
"1F1B",
"Interleaved1F1B",
"GPipe",
"FlexibleInterleaved1F1B",
"LoopedBFS",
"PipelineScheduleSingle",
"PipelineScheduleMulti",
]
# Test each schedule name
for name in schedule_names:
with self.subTest(name=name):
schedule_class = get_schedule_class(name)
self.assertIsNotNone(
schedule_class, f"Class for {name} should not be None"
)
self.assertTrue(
issubclass(schedule_class, _PipelineSchedule),
f"{name} should be a subclass of _PipelineSchedule",
)
class TestSchedulePlan(TestCase):
    """Validates the compute-only pipeline orders produced by multi-stage schedules."""

    def setUp(self):
        # (num_local_stages, num_microbatches, group_size) triples.
        # Cases where num_microbatches % group_size != 0 are only valid for the
        # flexible schedule; test_pipeline_order skips them explicitly.
        self.test_cases = [
            # small number of stages
            (2, 2, 2),
            (2, 4, 4),
            (2, 8, 2),
            (2, 8, 4),
            (2, 8, 8),
            (4, 4, 4),
            (4, 8, 4),
            (4, 8, 8),
            # large microbatches
            (4, 16, 4),
            (4, 32, 4),
            (4, 64, 4),
            # large groups
            (4, 16, 16),
            (4, 32, 32),
            (4, 128, 64),
            # odd num pipeline stages
            (3, 2, 2),
            (3, 8, 2),
            (3, 12, 4),
            # odd group_sizes
            (4, 6, 3),
            (4, 10, 5),
            # n_mb non divisible by group_size
            (2, 3, 4),
            (2, 4, 4),
            (2, 10, 4),
            (2, 15, 4),
        ]

    @parametrize(
        "ScheduleClass",
        [ScheduleInterleaved1F1B, ScheduleLoopedBFS],
    )
    def test_pipeline_order(self, ScheduleClass):
        """Schedules that require divisibility produce a valid pipeline order."""
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                # These schedules require num_microbatches % group_size == 0.
                if num_microbatches % group_size != 0:
                    continue
                logger.info(
                    "num_local_stages=%d num_microbatches=%d group_size=%d",
                    num_local_stages,
                    num_microbatches,
                    group_size,
                )
                num_stages = num_local_stages * group_size
                stages = [
                    MockPipelineStage(group_size=group_size, num_stages=num_stages)
                    for _ in range(num_local_stages)
                ]
                schedule = ScheduleClass(stages, num_microbatches)
                # Formatting is a smoke test: it must not raise. The string
                # itself is only useful for debugging.
                _format_pipeline_order(schedule.pipeline_order)
                _validate_pipeline_order(
                    schedule.pipeline_order, num_microbatches, num_stages
                )

    @parametrize(
        "ScheduleClass",
        [ScheduleFlexibleInterleaved1F1B],
    )
    def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass):
        """The flexible schedule produces a valid order with and without zero-bubble."""
        for num_local_stages, num_microbatches, group_size in self.test_cases:
            with self.subTest(
                num_local_stages=num_local_stages,
                num_microbatches=num_microbatches,
                group_size=group_size,
            ):
                for enable_zero_bubble in (True, False):
                    num_stages = num_local_stages * group_size
                    stages = [
                        MockPipelineStage(group_size=group_size, num_stages=num_stages)
                        for _ in range(num_local_stages)
                    ]
                    schedule = ScheduleClass(
                        stages, num_microbatches, enable_zero_bubble=enable_zero_bubble
                    )
                    # Formatting is a smoke test: it must not raise.
                    _format_pipeline_order(schedule.pipeline_order)
                    _validate_pipeline_order(
                        schedule.pipeline_order,
                        num_microbatches,
                        num_stages,
                        enable_zero_bubble=enable_zero_bubble,
                    )
# Expand the @parametrize decorators above into concrete test methods.
instantiate_parametrized_tests(TestSchedulePlan)
class TestScheduleLowering(TestCase):
"""Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules"""
def _parse_actions(self, actions: List[str]) -> List[_Action]:
return [_Action.from_str(s) for s in actions]
@parametrize(
"action_str_and_ref",
[
("1F0", _Action(1, F, 0)),
("2B1", _Action(2, B, 1)),
("0W3", _Action(0, W, 3)),
("1UNSHARD", _Action(1, UNSHARD, None)),
("3RESHARD", _Action(3, RESHARD, None)),
("2SEND_B2", _Action(2, SEND_B, 2)),
("1RECV_F1", _Action(1, RECV_F, 1)),
],
)
def test_action_parse(self, action_str_and_ref):
"""Test that actions can be parsed from strings and round-tripped back to the same strings."""
act_str, ref = action_str_and_ref
act = _Action.from_str(act_str)
self.assertEqual(act, ref)
self.assertEqual(act_str, act.__repr__())
@parametrize(
"test_info",
[
{
"compute": ["0F0", "0F1", " ", "0B0", "0B1"],
"comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"],
},
],
)
def test_unshard_reshard(self, test_info):
"""Test the lowering pass that takes a 'compute only' schedule (with only F,B,W ops) and adds
FSDP unshard/reshard operations to the schedule. This is just part of the process of adding communication
ops and producing a complete schedule.
"""
compute_sch = self._parse_actions(test_info["compute"])
expected_comms_sch = self._parse_actions(test_info["comms"])
comms_sch = _add_unshard_reshard(compute_sch)
for expected, actual in zip(expected_comms_sch, comms_sch):
self.assertEqual(
expected,
actual,
(
f"Mismatch: expected action {expected} but found {actual}."
f"\nWhole Schedule: {comms_sch}"
),
)
@parametrize(
"test_info",
[
{
"compute": {
0: ["0F0", "0F1", " ", "0B0", " ", "0B1"],
1: [" ", "1F0", "1B0", "1F1", "1B1", " "],
},
"comms": {
0: [
"0F0",
"0SEND_F0",
"0F1",
"0SEND_F1",
"0RECV_B0",
"0B0",
"0RECV_B1",
"0B1",
],
1: [
"1RECV_F0",
"1RECV_F1",
"1F0",
"1B0",
"1SEND_B0",
"1F1",
"1B1",
"1SEND_B1",
],
},
"stage_to_rank": lambda stage_idx: stage_idx,
"num_stages": 2,
},
],
)
def test_send_recv(self, test_info):
"""Tests the lowering pass that adds send/recv ops to a compute-only schedule."""
compute_sch = {
rank: self._parse_actions(test_info["compute"][rank])
for rank in test_info["compute"]
}
expected_comms_sch = {
rank: self._parse_actions(test_info["comms"][rank])
for rank in test_info["comms"]
}
comms_sch = _add_send_recv(
compute_sch, test_info["stage_to_rank"], test_info["num_stages"]
)
for rank in expected_comms_sch:
for i, (expected, actual) in enumerate(
zip(expected_comms_sch[rank], comms_sch[rank])
):
self.assertEqual(
expected,
actual,
(
f"Mismatch on rank {rank} at position {i}."
f"\nExpected: {expected_comms_sch[rank]}"
f"\nActual: {comms_sch[rank]}"
),
)
self.assertEqual(len(comms_sch[rank]), len(expected_comms_sch[rank]))
# Expand the @parametrize decorators above into concrete test methods.
instantiate_parametrized_tests(TestScheduleLowering)
# Standard PyTorch test-suite entry point.
if __name__ == "__main__":
    run_tests()