| # Copyright (c) Meta Platforms, Inc. and affiliates |
| # Owner(s): ["oncall: distributed"] |
| import logging |
| from typing import List |
| |
| import torch |
| from torch.distributed.pipelining import ( |
| ScheduleFlexibleInterleaved1F1B, |
| ScheduleInterleaved1F1B, |
| ScheduleLoopedBFS, |
| ) |
| from torch.distributed.pipelining.schedules import ( |
| _Action, |
| _add_send_recv, |
| _add_unshard_reshard, |
| _format_pipeline_order, |
| _PipelineSchedule, |
| _validate_pipeline_order, |
| B, |
| F, |
| get_schedule_class, |
| RECV_F, |
| RESHARD, |
| SEND_B, |
| UNSHARD, |
| W, |
| ) |
| from torch.distributed.pipelining.stage import _PipelineStageBase |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| |
| |
# Seed torch's default RNG once, at import time.
torch.manual_seed(0)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
| |
| |
| class MockPipelineStage(_PipelineStageBase): |
| def __init__(self, *args, **kwargs): |
| # Mock the necessary attributes |
| self.num_stages = kwargs.get("num_stages", 1) |
| self.group_size = kwargs.get("group_size", 1) |
| self.group_rank = kwargs.get("group_rank", 0) |
| self.group = kwargs.get("group", None) |
| self.stage_index_to_group_rank = kwargs.get("stage_index_to_group_rank", None) |
| |
| def _create_grad_recv_info(self, *args, **kwargs): |
| return None |
| |
| def _prepare_forward_infra(self, n_microbatches): |
| pass |
| |
| def _prepare_backward_infra(self, n_microbatches): |
| pass |
| |
| |
| class ScheduleTest(TestCase): |
| def test_get_schedule_class(self): |
| # List of all expected schedule names |
| schedule_names = [ |
| "1F1B", |
| "Interleaved1F1B", |
| "GPipe", |
| "FlexibleInterleaved1F1B", |
| "LoopedBFS", |
| "PipelineScheduleSingle", |
| "PipelineScheduleMulti", |
| ] |
| |
| # Test each schedule name |
| for name in schedule_names: |
| with self.subTest(name=name): |
| schedule_class = get_schedule_class(name) |
| self.assertIsNotNone( |
| schedule_class, f"Class for {name} should not be None" |
| ) |
| self.assertTrue( |
| issubclass(schedule_class, _PipelineSchedule), |
| f"{name} should be a subclass of _PipelineSchedule", |
| ) |
| |
| |
| class TestSchedulePlan(TestCase): |
| def setUp(self): |
| # Define a list of test cases with varying num_local_stages, num_microbatches, and group_size |
| # These should succeed since num_microbatches % group_size == 0 |
| self.test_cases = [ |
| # small number of stages |
| (2, 2, 2), |
| (2, 4, 4), |
| (2, 8, 2), |
| (2, 8, 4), |
| (2, 8, 8), |
| (4, 4, 4), |
| (4, 8, 4), |
| (4, 8, 8), |
| # large microbatches |
| (4, 16, 4), |
| (4, 32, 4), |
| (4, 64, 4), |
| # large groups |
| (4, 16, 16), |
| (4, 32, 32), |
| (4, 128, 64), |
| # odd num pipeline stages |
| (3, 2, 2), |
| (3, 8, 2), |
| (3, 12, 4), |
| # odd group_sizes |
| (4, 6, 3), |
| (4, 10, 5), |
| # n_mb non divisible by group_size |
| (2, 3, 4), |
| (2, 4, 4), |
| (2, 10, 4), |
| (2, 15, 4), |
| ] |
| |
| @parametrize( |
| "ScheduleClass", |
| [ScheduleInterleaved1F1B, ScheduleLoopedBFS], |
| ) |
| def test_pipeline_order(self, ScheduleClass): |
| for num_local_stages, num_microbatches, group_size in self.test_cases: |
| with self.subTest( |
| num_local_stages=num_local_stages, |
| num_microbatches=num_microbatches, |
| group_size=group_size, |
| ): |
| if num_microbatches % group_size != 0: |
| continue |
| |
| logger.info( |
| "num_local_stages=%d num_microbatches=%d group_size=%d", |
| num_local_stages, |
| num_microbatches, |
| group_size, |
| ) |
| num_stages = num_local_stages * group_size |
| stages = [ |
| MockPipelineStage(group_size=group_size, num_stages=num_stages) |
| for i in range(num_local_stages) |
| ] |
| |
| schedule = ScheduleClass(stages, num_microbatches) |
| formatted_pipeline_order = _format_pipeline_order( |
| schedule.pipeline_order |
| ) |
| # print(formatted_pipeline_order) |
| _validate_pipeline_order( |
| schedule.pipeline_order, num_microbatches, num_stages |
| ) |
| |
| @parametrize( |
| "ScheduleClass", |
| [ScheduleFlexibleInterleaved1F1B], |
| ) |
| def test_pipeline_order_flex_and_zero_bubble(self, ScheduleClass): |
| for num_local_stages, num_microbatches, group_size in self.test_cases: |
| with self.subTest( |
| num_local_stages=num_local_stages, |
| num_microbatches=num_microbatches, |
| group_size=group_size, |
| ): |
| warmups_ops_last_stage = (num_local_stages - 1) * ( |
| num_microbatches // max(1, num_microbatches // group_size) |
| ) |
| warmup_ops = warmups_ops_last_stage + 2 * (group_size - 1) |
| warmup_ops = min(warmup_ops, num_microbatches * num_local_stages) |
| |
| for i in range(2): |
| num_stages = num_local_stages * group_size |
| stages = [ |
| MockPipelineStage(group_size=group_size, num_stages=num_stages) |
| for i in range(num_local_stages) |
| ] |
| schedule = ScheduleClass( |
| stages, num_microbatches, enable_zero_bubble=(i == 0) |
| ) |
| formatted_pipeline_order = _format_pipeline_order( |
| schedule.pipeline_order |
| ) |
| # print(formatted_pipeline_order) |
| _validate_pipeline_order( |
| schedule.pipeline_order, |
| num_microbatches, |
| num_stages, |
| enable_zero_bubble=(i == 0), |
| ) |
| |
| |
| instantiate_parametrized_tests(TestSchedulePlan) |
| |
| |
| class TestScheduleLowering(TestCase): |
| """Tests lowering passes that convert simple compute-only (FBW) schedules into compute+comms schedules""" |
| |
| def _parse_actions(self, actions: List[str]) -> List[_Action]: |
| return [_Action.from_str(s) for s in actions] |
| |
| @parametrize( |
| "action_str_and_ref", |
| [ |
| ("1F0", _Action(1, F, 0)), |
| ("2B1", _Action(2, B, 1)), |
| ("0W3", _Action(0, W, 3)), |
| ("1UNSHARD", _Action(1, UNSHARD, None)), |
| ("3RESHARD", _Action(3, RESHARD, None)), |
| ("2SEND_B2", _Action(2, SEND_B, 2)), |
| ("1RECV_F1", _Action(1, RECV_F, 1)), |
| ], |
| ) |
| def test_action_parse(self, action_str_and_ref): |
| """Test that actions can be parsed from strings and round-tripped back to the same strings.""" |
| act_str, ref = action_str_and_ref |
| act = _Action.from_str(act_str) |
| self.assertEqual(act, ref) |
| self.assertEqual(act_str, act.__repr__()) |
| |
| @parametrize( |
| "test_info", |
| [ |
| { |
| "compute": ["0F0", "0F1", " ", "0B0", "0B1"], |
| "comms": ["0UNSHARD", "0F0", "0F1", "0B0", "0B1", "0RESHARD"], |
| }, |
| ], |
| ) |
| def test_unshard_reshard(self, test_info): |
| """Test the lowering pass that takes a 'compute only' schedule (with only F,B,W ops) and adds |
| FSDP unshard/reshard operations to the schedule. This is just part of the process of adding communication |
| ops and producing a complete schedule. |
| """ |
| compute_sch = self._parse_actions(test_info["compute"]) |
| expected_comms_sch = self._parse_actions(test_info["comms"]) |
| |
| comms_sch = _add_unshard_reshard(compute_sch) |
| for expected, actual in zip(expected_comms_sch, comms_sch): |
| self.assertEqual( |
| expected, |
| actual, |
| ( |
| f"Mismatch: expected action {expected} but found {actual}." |
| f"\nWhole Schedule: {comms_sch}" |
| ), |
| ) |
| |
| @parametrize( |
| "test_info", |
| [ |
| { |
| "compute": { |
| 0: ["0F0", "0F1", " ", "0B0", " ", "0B1"], |
| 1: [" ", "1F0", "1B0", "1F1", "1B1", " "], |
| }, |
| "comms": { |
| 0: [ |
| "0F0", |
| "0SEND_F0", |
| "0F1", |
| "0SEND_F1", |
| "0RECV_B0", |
| "0B0", |
| "0RECV_B1", |
| "0B1", |
| ], |
| 1: [ |
| "1RECV_F0", |
| "1RECV_F1", |
| "1F0", |
| "1B0", |
| "1SEND_B0", |
| "1F1", |
| "1B1", |
| "1SEND_B1", |
| ], |
| }, |
| "stage_to_rank": lambda stage_idx: stage_idx, |
| "num_stages": 2, |
| }, |
| ], |
| ) |
| def test_send_recv(self, test_info): |
| """Tests the lowering pass that adds send/recv ops to a compute-only schedule.""" |
| compute_sch = { |
| rank: self._parse_actions(test_info["compute"][rank]) |
| for rank in test_info["compute"] |
| } |
| expected_comms_sch = { |
| rank: self._parse_actions(test_info["comms"][rank]) |
| for rank in test_info["comms"] |
| } |
| |
| comms_sch = _add_send_recv( |
| compute_sch, test_info["stage_to_rank"], test_info["num_stages"] |
| ) |
| for rank in expected_comms_sch: |
| for i, (expected, actual) in enumerate( |
| zip(expected_comms_sch[rank], comms_sch[rank]) |
| ): |
| self.assertEqual( |
| expected, |
| actual, |
| ( |
| f"Mismatch on rank {rank} at position {i}." |
| f"\nExpected: {expected_comms_sch[rank]}" |
| f"\nActual: {comms_sch[rank]}" |
| ), |
| ) |
| self.assertEqual(len(comms_sch[rank]), len(expected_comms_sch[rank])) |
| |
| |
| instantiate_parametrized_tests(TestScheduleLowering) |
| |
if __name__ == "__main__":
    # Executed as a script: hand control to the imported test runner.
    run_tests()