Add tests for replicating multiple modules (#89099)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099
Approved by: https://github.com/zhaojuanmao
diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py
index 831ccc3..3e8bf44 100644
--- a/test/distributed/_composable/test_replicate.py
+++ b/test/distributed/_composable/test_replicate.py
@@ -39,13 +39,7 @@
         except OSError:
             pass
 
-    def _prepare_module(self, global_batch_size):
-        model = Net()
-        input = torch.randn(global_batch_size, 2)
-        target = torch.randn(global_batch_size, 4)
-        return model, input, target
-
-    def test_replicate(self):
+    def _compare_module(self, mod, replicate_mod):
         dist.init_process_group(
             backend="gloo",
             rank=self.rank,
@@ -55,8 +49,8 @@
 
         local_batch_size = 1
         global_batch_size = self.world_size * local_batch_size
-        model, input, target = self._prepare_module(global_batch_size)
-        replicate_model = mark_root_module(replicate(deepcopy(model)))
+        input = torch.randn(global_batch_size, 2)
+        target = torch.randn(global_batch_size, 4)
 
         def step_model(model, input, target):
             model.train()
@@ -69,9 +63,9 @@
                 param.grad = None
 
         for iteration in range(2):
-            step_model(model, input, target)
+            step_model(mod, input, target)
             step_model(
-                replicate_model,
+                replicate_mod,
                 input[
                     self.rank
                     * local_batch_size : (self.rank + 1)
@@ -85,16 +79,29 @@
             )
 
             self.assertEqual(
-                len(list(model.parameters())),
-                len(list(replicate_model.parameters())),
+                len(list(mod.parameters())),
+                len(list(replicate_mod.parameters())),
             )
-            for i, j in zip(model.parameters(), replicate_model.parameters()):
+            for i, j in zip(mod.parameters(), replicate_mod.parameters()):
                 self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
 
             # Shuffle the input so that DDP input is different
             torch.manual_seed(iteration)
             input = input[torch.randperm(global_batch_size)]
 
+    def test_replicate_single_module(self):
+        model = Net()
+        replicate_model = mark_root_module(replicate(deepcopy(model)))
+        self._compare_module(model, replicate_model)
+
+    def test_replicate_multi_module(self):
+        model = Net()
+        replicate_model = mark_root_module(deepcopy(model))
+        replicate(replicate_model.fc1)
+        replicate(replicate_model.fc2)
+        replicate(replicate_model.fc3)
+        self._compare_module(model, replicate_model)
+
 
 if __name__ == "__main__":
     run_tests()