| # Copyright (c) Meta Platforms, Inc. and affiliates |
| # Owner(s): ["oncall: distributed"] |
| from model_registry import MLPModule, ModelWithParamAlias |
| |
| import torch |
| from torch.distributed.pipelining import pipe_split, pipeline |
| from torch.testing._internal.common_utils import ( |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| TestCase, |
| ) |
| |
| |
# Feature width used for every parameter, linear layer and input in this file.
d_hid = 512
# Number of rows in the example inputs handed to `pipeline` as one microbatch.
microbatch_size = 16

# Seed the default RNG at import time so random weights and inputs are reproducible.
torch.manual_seed(0)
| |
| |
| # Basic example |
class ExampleCode(torch.nn.Module):
    """Example model cut into four segments by three ``pipe_split()`` markers.

    ``mm_param1`` is consumed on both sides of the first split (a multi-use
    parameter), and the result of the first matmul is reused as a skip
    connection after the second split.
    """

    def __init__(self):
        super().__init__()
        square = (d_hid, d_hid)
        # Creation order is kept as-is: each call draws from the global RNG.
        self.mm_param1 = torch.nn.Parameter(torch.randn(square))
        self.mm_param2 = torch.nn.Parameter(torch.randn(square))
        self.lin1 = torch.nn.Linear(d_hid, d_hid)
        self.lin2 = torch.nn.Linear(d_hid, d_hid)

    def forward(self, x, y):
        """Run all four segments in order on ``x`` and ``y`` and return the result."""
        # Segment 0: first use of the multi-use parameter; keep the matmul
        # result around for the skip connection.
        skip = torch.mm(x, self.mm_param1)
        hidden = torch.relu(skip + y)
        pipe_split()
        # Segment 1: second use of the multi-use parameter.
        hidden = self.lin1(torch.mm(hidden, self.mm_param1))
        pipe_split()
        # Segment 2: consumes the skip connection produced in segment 0.
        hidden = torch.relu(hidden) + skip
        hidden = torch.mm(hidden, self.mm_param2)
        pipe_split()
        # Segment 3.
        return torch.relu(self.lin2(hidden))
| |
| |
class MultiMLP(torch.nn.Module):
    """Four ``MLPModule`` blocks chained together, with a split after each of the first three."""

    def __init__(self):
        super().__init__()
        # Registers mlp0 .. mlp3 in that order, each built with width d_hid.
        for idx in range(4):
            setattr(self, f"mlp{idx}", MLPModule(d_hid))

    def forward(self, x, y):
        """Feed ``x`` through mlp0..mlp3 and subtract ``y`` from the final output."""
        hidden = self.mlp0(x)
        pipe_split()
        hidden = self.mlp1(hidden)
        pipe_split()
        hidden = self.mlp2(hidden)
        pipe_split()
        return self.mlp3(hidden) - y
| |
| |
# Stage count the test expects `pipeline` to report for each model class.
# ExampleCode and MultiMLP each contain three pipe_split() markers, hence 4.
EXPECTED_N_STAGES = {ExampleCode: 4, MultiMLP: 4, ModelWithParamAlias: 2}
| |
# Full set equality between the FQNs of the original model and the union of
# the FQNs of the pipelined stage modules is not enforced for now: when a
# parameter is used more than once, PP deduplicates its FQNs in the
# state_dict, so the two sets can differ.
# TODO: set this to True once the multi-use param case preserves all FQNs.
CHECK_FQN_SET_EQUALITY = False
| |
| |
class PipeTests(TestCase):
    """Tests that `pipeline` splits the example models as expected."""

    @parametrize("ModelClass", [ExampleCode, MultiMLP, ModelWithParamAlias])
    def test_model_split(self, ModelClass):
        """Split ``ModelClass`` with `pipeline` and compare it to the eager model.

        Checks, in order:
          1. the number of stages equals ``EXPECTED_N_STAGES[ModelClass]``;
          2. the first output of the pipe is close to the original module's
             output on the same inputs;
          3. every state_dict key of every stage module is also a key of the
             original module's state_dict (and, when
             ``CHECK_FQN_SET_EQUALITY`` is set, that the union of the stage
             keys equals the original key set).

        Uses unittest assertion methods rather than bare ``assert`` so the
        checks are not stripped under ``python -O`` and failures carry a
        diagnostic message.
        """
        mod = ModelClass()
        x = torch.randn(microbatch_size, d_hid)
        y = torch.randn(microbatch_size, d_hid)

        pipe = pipeline(
            mod,
            mb_args=(x, y),
        )

        expected_n_stages = EXPECTED_N_STAGES[ModelClass]
        self.assertEqual(
            pipe.num_stages,
            expected_n_stages,
            f"nstages = {pipe.num_stages}, expect {expected_n_stages}",
        )

        ref_out = mod(x, y)
        out = pipe(x, y)[0]
        torch.testing.assert_close(out, ref_out)
        print(f"equivalence test passed {torch.sum(out)} ref {torch.sum(ref_out)}")

        # Check qualname
        # state_dict keys include both parameters and persistent buffers
        old_names = set(mod.state_dict())
        new_names = set()
        for idx in range(pipe.num_stages):
            stage_mod = pipe.get_stage_module(idx)
            stage_fqns = set(stage_mod.state_dict())
            # A stage must not introduce names the original model lacks.
            self.assertTrue(
                stage_fqns.issubset(old_names),
                f"stage {idx} has FQNs not present in the original model: "
                f"{sorted(stage_fqns - old_names)}",
            )
            new_names.update(stage_fqns)

        if CHECK_FQN_SET_EQUALITY:
            self.assertEqual(
                old_names,
                new_names,
                f"""
            old names {old_names}
            new names {new_names}
            """,
            )
        print("Qualname check passed")
| |
| |
# Must run at import time, after the class body, so the @parametrize-decorated
# test on PipeTests is processed before the test runner collects it.
instantiate_parametrized_tests(PipeTests)

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()