[TP][Tests] Replace assertEqual with deepcopy (#123218)
There were a lot of manual `assertEqual`'s in the tests to make sure `model_tp` was created the same as `model`.
`model_tp = copy.deepcopy(model)` should help us rest assured.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123218
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/tensor/parallel/test_parallelize_api.py b/test/distributed/tensor/parallel/test_parallelize_api.py
index 019efd9..ed5a736 100644
--- a/test/distributed/tensor/parallel/test_parallelize_api.py
+++ b/test/distributed/tensor/parallel/test_parallelize_api.py
@@ -101,13 +101,7 @@
def test_parallelize_mlp_with_module_api(self):
inp_size = [12, 10]
model = MLPModule(self.device_type)
- model_tp = MLPModule(self.device_type)
-
- # Ensure model are initialized the same way.
- self.assertEqual(model.net1.weight, model_tp.net1.weight)
- self.assertEqual(model.net1.bias, model_tp.net1.bias)
- self.assertEqual(model.net2.weight, model_tp.net2.weight)
- self.assertEqual(model.net2.bias, model_tp.net2.bias)
+ model_tp = deepcopy(model)
# Parallelize module.
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -127,23 +121,7 @@
model = torch.nn.Sequential(
OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
)
- model_tp = torch.nn.Sequential(
- OrderedDict([("dummy_encoder", MLPModule(self.device_type))])
- )
-
- # Ensure model are initialized the same way.
- self.assertEqual(
- model.dummy_encoder.net1.weight, model_tp.dummy_encoder.net1.weight
- )
- self.assertEqual(
- model.dummy_encoder.net1.bias, model_tp.dummy_encoder.net1.bias
- )
- self.assertEqual(
- model.dummy_encoder.net2.weight, model_tp.dummy_encoder.net2.weight
- )
- self.assertEqual(
- model.dummy_encoder.net2.bias, model_tp.dummy_encoder.net2.bias
- )
+ model_tp = deepcopy(model)
# Parallelize module.
device_mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
@@ -165,8 +143,7 @@
torch.manual_seed(5)
model = torch.nn.Linear(16, 10, device=self.device_type)
- torch.manual_seed(5)
- model_tp = torch.nn.Linear(16, 10, device=self.device_type)
+ model_tp = deepcopy(model)
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))
@@ -184,8 +161,7 @@
torch.manual_seed(5)
model = torch.nn.Linear(10, 16, device=self.device_type)
- torch.manual_seed(5)
- model_tp = torch.nn.Linear(10, 16, device=self.device_type)
+ model_tp = deepcopy(model)
# parallelize model_tp
device_mesh = DeviceMesh(self.device_type, list(range(self.world_size)))