# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear, Module
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    subtest,
    TEST_WITH_DEV_DBG_ASAN,
)
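
# Bail out early in environments where this test cannot run: builds without
# distributed support, and dev-asan builds, where torch + multiprocessing
# spawn have known issues.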
if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)
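

# Single-GPU test that FSDP's forward pass accepts inputs nested in
# containers (a list or a dict), not just bare tensors.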
class TestInput(FSDPTest):
    @property
    def world_size(self):
        return 1
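
    # Run the test once per container type; subtest() gives each variant a
    # readable name suffix (dict/list).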
    @skip_if_lt_x_gpu(1)
    @parametrize("input_cls", [subtest(dict, name="dict"), subtest(list, name="list")])
    def test_input_type(self, input_cls):
        """Tests FSDP with an input that is a list or a dict, on a single GPU."""

        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                # Unwrap the tensor from its container before applying the
                # linear layer.
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        # Shard the model with FSDP, then move it to the GPU.
        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)
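
        # A few forward/backward/optimizer iterations with the input tensor
        # nested in the parametrized container type.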
        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()


# Generate the parametrized test variants (one per subtest name) on TestInput.
instantiate_parametrized_tests(TestInput)

if __name__ == "__main__":
    run_tests()