Add accuracy tests for traced optimizers (#97577)
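
Compare traced-optimizer training against an eager DDP baseline: the new
_test_optimizer helper materializes optimizer states on both models, runs one
traced train_step on the local model and one manual step on the DDP replica,
and then asserts parameter parity. The SGD and Adam tests are updated to use
the helper (bias disabled and seed fixed for now; see the inline FIXMEs).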
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97577
Approved by: https://github.com/yifuwang
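
For context, the train_step functions these tests trace follow the pattern
already present in this file; a minimal sketch is below. The @compile()
decorator and its import path are assumptions and are not part of this diff,
which only shows the train_step body and call signature.

    # Hedged sketch of the traced train_step pattern exercised by these tests.
    # The decorator and import path are assumed, not taken from this diff.
    from torch.distributed._spmd.api import compile  # assumed import path

    @compile()
    def train_step(mod, opt, inp):
        # one training iteration: forward, backward, optimizer update
        mod(inp).sum().backward()
        opt.step()
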
diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py
index 000d2e8..6c8f0e6 100644
--- a/test/distributed/_spmd/test_tracing.py
+++ b/test/distributed/_spmd/test_tracing.py
@@ -517,6 +517,29 @@
             # Should we match that behavior?
             self.assertEqual(g1 / self.world_size, p2.grad)
+    def _test_optimizer(self, mod, ddp_mod, opt, ddp_opt, inp, train_step):
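+        """Check that a traced ``train_step`` matches an eager DDP baseline.
+
+        Both models first take one eager forward/backward/step to materialize
+        optimizer states; ``train_step`` then updates ``mod`` while the DDP
+        replica is stepped manually, and the resulting parameters must match.
+        """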
+        ddp_inp = deepcopy(inp)
+
+        # materialize optimizer states
+        mod(inp).sum().backward()
+        opt.step()
+        opt.zero_grad()
+
+        ddp_mod(ddp_inp).sum().backward()
+        ddp_opt.step()
+        ddp_opt.zero_grad()
+
+        # test parameter parity
+        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce, so
+        # only testing local model for now
+        train_step(mod, opt, inp)
+
+        ddp_mod(ddp_inp).sum().backward()
+        ddp_opt.step()
+
+        for p1, p2 in zip(mod.parameters(), ddp_mod.parameters()):
+            self.assertEqual(p1, p2)
+
     @skip_if_lt_x_gpu(2)
     @with_comms
     def test_sgd(self):
@@ -526,16 +549,18 @@
             opt.step()
         rank = torch.distributed.get_rank()
-        mod = nn.Linear(10, 10).cuda(rank)
+        # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize
+        # module parameters.
+        torch.manual_seed(0)
+        # FIXME(@mrshenli): gradients for the bias are missing
+        mod = nn.Linear(10, 10, bias=False).cuda(rank)
         # FIXME(@mrshenli): we have to enable foreach to get better perf
         opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False)
-        inp = torch.zeros(2, 10).cuda(rank)
+        inp = torch.randn(2, 10).cuda(rank)
-        mod(inp).sum().backward()
-        opt.step()
-
-        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
-        train_step(mod, opt, inp)
+        ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
+        ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=False)
+        self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)
     @skip_if_lt_x_gpu(2)
     @with_comms
@@ -546,17 +571,20 @@
             opt.step()
         rank = torch.distributed.get_rank()
-        mod = nn.Linear(10, 10).cuda(rank)
+        # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize
+        # module parameters.
+        torch.manual_seed(0)
+        # FIXME(@mrshenli): gradients for the bias are missing
+        mod = nn.Linear(10, 10, bias=False).cuda(rank)
+        # FIXME(@mrshenli): we have to enable foreach to get better perf
         opt = torch.optim.Adam(
             mod.parameters(), lr=0.01, foreach=False, capturable=True
         )
-        inp = torch.zeros(2, 10).cuda(rank)
+        inp = torch.randn(2, 10).cuda(rank)
-        mod(inp).sum().backward()
-        opt.step()
-
-        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
-        train_step(mod, opt, inp)
+        ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
+        ddp_opt = torch.optim.Adam(ddp_mod.parameters(), lr=0.01, foreach=False)
+        self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)
     @skip_if_lt_x_gpu(2)
     @with_comms
@@ -606,7 +634,10 @@
         train_step(mod, opt, inp)
         # checking transforms are indeed invoked.
-        self.assertEqual(transform_targets, [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default])
+        self.assertEqual(
+            transform_targets,
+            [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default],
+        )
 if __name__ == "__main__":
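
For reference, a hypothetical follow-up test (not part of this change) could
reuse the new helper in exactly the same way; everything below, including the
@compile decorator and the choice of AdamW, is an illustrative assumption that
leans on the imports and decorators already used in this file.

    @skip_if_lt_x_gpu(2)
    @with_comms
    def test_adamw(self):  # hypothetical example, not in this PR
        @compile()  # assumed _spmd compile decorator, as in the tests above
        def train_step(mod, opt, inp):
            mod(inp).sum().backward()
            opt.step()

        rank = torch.distributed.get_rank()
        torch.manual_seed(0)
        mod = nn.Linear(10, 10, bias=False).cuda(rank)
        opt = torch.optim.AdamW(
            mod.parameters(), lr=0.01, foreach=False, capturable=True
        )
        inp = torch.randn(2, 10).cuda(rank)
        ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
        ddp_opt = torch.optim.AdamW(ddp_mod.parameters(), lr=0.01, foreach=False)
        self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)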