# Owner(s): ["oncall: distributed"]

import sys

import torch
from torch import distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.nn import Linear
from torch.optim import SGD
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)


class TestUnevenParamShard(FSDPTest):
    def _get_ref_results(self, model, input, my_lr):
        """Compute the reference local forward output for this rank and the
        reference weight after one SGD step, emulating FSDP's gradient
        averaging across ranks."""
        with torch.no_grad():
            # Compute one iteration local output.
            weight = model.weight.T.clone().to(self.rank)
            v = torch.Tensor(input[self.rank]).to(self.rank)
            ref_forward_output_my_rank = torch.matmul(v, weight)
            # Compute one iteration global weight update.
            v = torch.Tensor(input[: self.world_size]).to(self.rank)
            grad = v.float().sum(0).repeat(weight.shape[0], 1).div(self.world_size)
            ref_weight_out = weight - grad.T * my_lr
        return ref_forward_output_my_rank, ref_weight_out

    @skip_if_lt_x_gpu(2)
    def test_one_iteration(self):
        """Test FSDP with an uneven division of parameter shards."""
        model = Linear(3, 3, bias=False)
        input = torch.rand(8, 3)
        my_lr = 0.1

        ref_forward_output_my_rank, ref_weight_out = self._get_ref_results(
            model, input, my_lr
        )

        # Wrap the model with FSDP; the single 3x3 weight (9 elements) is
        # flattened and sharded across ranks, which generally does not divide
        # evenly and is what this test exercises.
        model.to(self.rank)
        model = FSDP(model)
        optim = SGD(model.parameters(), lr=my_lr)
        self.assertTrue(len(input) >= self.world_size)

        # Run one forward/backward/optimizer step on this rank's input slice.
        in_data = torch.Tensor(input[self.rank]).to(self.rank)
        out = model(in_data)
        out.float().sum().backward()
        optim.step()
        optim.zero_grad()

        # Gather the full (unsharded) parameters and compare against the reference.
        with model.summon_full_params(model):
            weight_out = model.module.weight.T.clone()
            self.assertEqual(ref_forward_output_my_rank, out)
            self.assertEqual(ref_weight_out, weight_out)


if __name__ == "__main__":
    run_tests()