Add accuracy tests for traced optimizers (#97577)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97577
Approved by: https://github.com/yifuwang
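
What the new test helper checks: the traced/compiled optimizer path on a
local module should end up with the same parameters as a plain eager step
on a DDP replica. Both copies first take one eager forward/backward/step to
materialize optimizer states; the local copy then runs the compiled
train_step while the DDP copy takes another eager step, and the resulting
parameters are compared. A minimal single-process sketch of that parity
pattern (illustrative only; it omits the @compile tracing and the DDP
wrapping that the actual test exercises, and the module/optimizer choices
here are just examples):

    import torch
    import torch.nn as nn
    from copy import deepcopy

    torch.manual_seed(0)
    mod = nn.Linear(10, 10, bias=False)
    ref_mod = deepcopy(mod)

    opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False)
    ref_opt = torch.optim.SGD(ref_mod.parameters(), lr=0.01, foreach=False)

    inp = torch.randn(2, 10)

    # materialize optimizer states with one eager step on each copy
    for m, o in ((mod, opt), (ref_mod, ref_opt)):
        m(inp).sum().backward()
        o.step()
        o.zero_grad()

    # one more step on each copy; in the real test the first copy goes
    # through the traced train_step and the second through DDP + eager step
    for m, o in ((mod, opt), (ref_mod, ref_opt)):
        m(inp).sum().backward()
        o.step()

    # parameter parity: identical inputs and updates should give equal weights
    for p1, p2 in zip(mod.parameters(), ref_mod.parameters()):
        torch.testing.assert_close(p1, p2)
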
diff --git a/test/distributed/_spmd/test_tracing.py b/test/distributed/_spmd/test_tracing.py
index 000d2e8..6c8f0e6 100644
--- a/test/distributed/_spmd/test_tracing.py
+++ b/test/distributed/_spmd/test_tracing.py
@@ -517,6 +517,29 @@
             # Should we match that behavior?
             self.assertEqual(g1 / self.world_size, p2.grad)
 
+    def _test_optimizer(self, mod, ddp_mod, opt, ddp_opt, inp, train_step):
+        ddp_inp = deepcopy(inp)
+
+        # materialize optimizer states
+        mod(inp).sum().backward()
+        opt.step()
+        opt.zero_grad()
+
+        ddp_mod(ddp_inp).sum().backward()
+        ddp_opt.step()
+        ddp_opt.zero_grad()
+
+        # test parameter parity
+        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce, so
+        # only testing local model for now
+        train_step(mod, opt, inp)
+
+        ddp_mod(ddp_inp).sum().backward()
+        ddp_opt.step()
+
+        for p1, p2 in zip(mod.parameters(), ddp_mod.parameters()):
+            self.assertEqual(p1, p2)
+
     @skip_if_lt_x_gpu(2)
     @with_comms
     def test_sgd(self):
@@ -526,16 +549,18 @@
             opt.step()
 
         rank = torch.distributed.get_rank()
-        mod = nn.Linear(10, 10).cuda(rank)
+        # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize
+        # module parameters.
+        torch.manual_seed(0)
+        # FIXME(@mrshenli): gradients for bias are missing
+        mod = nn.Linear(10, 10, bias=False).cuda(rank)
         # FIXME(@mrshenli): we have to enable foreach to get better perf
         opt = torch.optim.SGD(mod.parameters(), lr=0.01, foreach=False)
-        inp = torch.zeros(2, 10).cuda(rank)
+        inp = torch.randn(2, 10).cuda(rank)
 
-        mod(inp).sum().backward()
-        opt.step()
-
-        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
-        train_step(mod, opt, inp)
+        ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
+        ddp_opt = torch.optim.SGD(ddp_mod.parameters(), lr=0.01, foreach=False)
+        self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)
 
     @skip_if_lt_x_gpu(2)
     @with_comms
@@ -546,17 +571,20 @@
             opt.step()
 
         rank = torch.distributed.get_rank()
-        mod = nn.Linear(10, 10).cuda(rank)
+        # FIXME(@mrshenli): remove manual seed once dist.compile can synchronize
+        # module parameters.
+        torch.manual_seed(0)
+        # FIXME(@mrshenli): gradients for bias are missing
+        mod = nn.Linear(10, 10, bias=False).cuda(rank)
+        # FIXME(@mrshenli): we have to enable foreach to get better perf
         opt = torch.optim.Adam(
             mod.parameters(), lr=0.01, foreach=False, capturable=True
         )
-        inp = torch.zeros(2, 10).cuda(rank)
+        inp = torch.randn(2, 10).cuda(rank)
 
-        mod(inp).sum().backward()
-        opt.step()
-
-        # FIXME(@mrshenli): inplace op + DTensor does not trigger allreduce
-        train_step(mod, opt, inp)
+        ddp_mod = DDP(deepcopy(mod), device_ids=[rank])
+        ddp_opt = torch.optim.Adam(ddp_mod.parameters(), lr=0.01, foreach=False)
+        self._test_optimizer(mod, ddp_mod, opt, ddp_opt, inp, train_step)
 
     @skip_if_lt_x_gpu(2)
     @with_comms
@@ -606,7 +634,10 @@
         train_step(mod, opt, inp)
 
         # checking transforms are indeed invoked.
-        self.assertEqual(transform_targets, [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default])
+        self.assertEqual(
+            transform_targets,
+            [torch.ops.dummy.ddm.default, torch.ops.dummy.ddm_backward.default],
+        )
 
 
 if __name__ == "__main__":