Add tests for applying replicate to multiple modules (#89099)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89099
Approved by: https://github.com/zhaojuanmao
diff --git a/test/distributed/_composable/test_replicate.py b/test/distributed/_composable/test_replicate.py
index 831ccc3..3e8bf44 100644
--- a/test/distributed/_composable/test_replicate.py
+++ b/test/distributed/_composable/test_replicate.py
@@ -39,13 +39,7 @@
except OSError:
pass
- def _prepare_module(self, global_batch_size):
- model = Net()
- input = torch.randn(global_batch_size, 2)
- target = torch.randn(global_batch_size, 4)
- return model, input, target
-
- def test_replicate(self):
+ def _compare_module(self, mod, replicate_mod):
dist.init_process_group(
backend="gloo",
rank=self.rank,
@@ -55,8 +49,8 @@
local_batch_size = 1
global_batch_size = self.world_size * local_batch_size
- model, input, target = self._prepare_module(global_batch_size)
- replicate_model = mark_root_module(replicate(deepcopy(model)))
+ input = torch.randn(global_batch_size, 2)
+ target = torch.randn(global_batch_size, 4)
def step_model(model, input, target):
model.train()
@@ -69,9 +63,9 @@
param.grad = None
for iteration in range(2):
- step_model(model, input, target)
+ step_model(mod, input, target)
step_model(
- replicate_model,
+ replicate_mod,
input[
self.rank
* local_batch_size : (self.rank + 1)
@@ -85,16 +79,29 @@
)
self.assertEqual(
- len(list(model.parameters())),
- len(list(replicate_model.parameters())),
+ len(list(mod.parameters())),
+ len(list(replicate_mod.parameters())),
)
- for i, j in zip(model.parameters(), replicate_model.parameters()):
+ for i, j in zip(mod.parameters(), replicate_mod.parameters()):
self.assertEqual(i, j, rtol=1.3e-06, atol=5e-5)
# Shuffle the input so that DDP input is different
torch.manual_seed(iteration)
input = input[torch.randperm(global_batch_size)]
+ def test_replicate_single_module(self):
+ model = Net()
+ replicate_model = mark_root_module(replicate(deepcopy(model)))
+ self._compare_module(model, replicate_model)
+
+ def test_replicate_multi_module(self):
+ model = Net()
+ replicate_model = mark_root_module(deepcopy(model))
+ replicate(replicate_model.fc1)
+ replicate(replicate_model.fc2)
+ replicate(replicate_model.fc3)
+ self._compare_module(model, replicate_model)
+
if __name__ == "__main__":
run_tests()