# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
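        """Compute reference results without FSDP: this rank's forward output
        and the full weight expected after one SGD step, with the gradient
        averaged over the first ``world_size`` rows of ``input``."""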
        with torch.no_grad():
            # Compute one iteration of the local output. nn.Linear computes
            # x @ W^T, so work with the transposed weight directly.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration of the global weight update. With the loss
            # out.sum(), every row of dL/dW equals the local input row, and the
            # gradient is averaged across ranks, hence the division by world_size.
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr

        return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test one iteration of FSDP with an uneven division of parameter shards."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

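        # Wrap the model in FSDP and run one forward/backward/optimizer step
        # on this rank's row of the input.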
        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

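        # Gather the full (unsharded) weight so it can be compared against the
        # locally computed reference.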
        with FSDP.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()