[dtensor][test] test case suite for comm_mode features (#128729)
**Summary**
Currently, there is only an example file for comm_mode and its features. I have created test cases that mirror those examples; the more complicated test cases also verify that comm_mode resets all of its internal state when used multiple times in the same function. This suite will also help developers ensure that new code added to comm_mode does not break the correctness of existing features.
#128536
**Test Plan**
pytest test/distributed/_tensor/debug/test_comm_mode_features.py
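For reference, a minimal sketch of the usage pattern this suite exercises, assuming a process group is already initialized and at least two GPUs are available (the `TwoLinear` toy model is illustrative, not part of this PR; the `CommDebugMode` accessors are the ones added in the diff below):
```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DeviceMesh
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    parallelize_module,
    RowwiseParallel,
)


# toy two-layer model with named children for the parallelize plan
class TwoLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.net1 = nn.Linear(10, 16)
        self.net2 = nn.Linear(16, 10)

    def forward(self, x):
        return self.net2(torch.relu(self.net1(x)))


mesh = DeviceMesh("cuda", torch.arange(torch.cuda.device_count()))
model = parallelize_module(
    TwoLinear().cuda(),
    mesh,
    {"net1": ColwiseParallel(), "net2": RowwiseParallel()},
)

comm_mode = CommDebugMode()
with comm_mode:
    model(torch.rand(8, 10, device="cuda")).sum().backward()

# per-module collective counts collected during the trace
all_reduces = comm_mode.comm_module_counts["Global"][torch.ops.c10d_functional.all_reduce]
# parameter and sharding info collected during the trace
param_info = comm_mode.get_parameter_info()    # {module_fqn: {param_name: tensor}}
sharding_info = comm_mode.get_sharding_info()  # {param_fqn: placements}
```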
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128729
Approved by: https://github.com/XilunWu
diff --git a/test/distributed/_tensor/debug/test_comm_mode_features.py b/test/distributed/_tensor/debug/test_comm_mode_features.py
new file mode 100644
index 0000000..a8b99d7
--- /dev/null
+++ b/test/distributed/_tensor/debug/test_comm_mode_features.py
@@ -0,0 +1,358 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# Owner(s): ["oncall: distributed"]
+
+from typing import Any, Dict
+
+import torch
+from torch.distributed._tensor import DeviceMesh
+from torch.distributed._tensor.api import distribute_tensor, DTensor
+from torch.distributed._tensor.debug import CommDebugMode
+from torch.distributed.tensor.parallel import (
+ ColwiseParallel,
+ parallelize_module,
+ RowwiseParallel,
+)
+from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed._tensor.common_dtensor import (
+ DTensorTestBase,
+ MLPModule,
+ MLPStacked,
+ ModelArgs,
+ NUM_DEVICES,
+ skip_unless_torch_gpu,
+ Transformer,
+ with_comms,
+)
+
+
+c10d_functional = torch.ops.c10d_functional
+
+
+class TestCommModeFeatures(DTensorTestBase):
+ # checks if parameter / sharding info is the same as ground truth
+ def check_same_set_of_keys(self, dict1, dict2):
+ """
+        Ensures that the comm_mode parameter/sharding dictionaries contain the same
+        keys as those produced by the ground truth
+ """
+ dict1_keys = []
+ dict2_keys = []
+
+ for key in dict1:
+ for nested_key in dict1[key]:
+ dict1_keys.append((key, nested_key))
+
+ for key in dict2:
+ for nested_key in dict2[key]:
+ dict2_keys.append((key, nested_key))
+
+        self.assertEqual(dict1_keys, dict2_keys)
+
+ # generates the ground truth parameter and sharding info
+ def ground_truth(self, model):
+ """
+ Used to generate the ground-truth parameter and sharding info for a given distributed model to
+ verify comm_mode correctness
+ """
+ module_parameters_dict: Dict[str, Any] = {}
+ module_sharding_dict: Dict[str, Any] = {}
+
+ for name, parameters in model.named_parameters():
+            # splits the parameter's FQN into a module path and a parameter name;
+            # assumes every parameter lives under a named submodule
+            module_name = model.__class__.__name__ + "." + name.rsplit(".", 1)[0]
+            parameter_name = name.rsplit(".", 1)[1]
+
+ if module_name not in module_parameters_dict:
+ module_parameters_dict[module_name] = {}
+
+ module_parameters_dict[module_name][parameter_name] = parameters.data
+
+ if isinstance(parameters.data, DTensor):
+ key_name = module_name + "." + parameter_name
+ module_sharding_dict[key_name] = parameters.data.placements
+
+ return module_parameters_dict, module_sharding_dict
+
+ @with_comms
+ def test_MLP_distributed_sharding_display(self):
+ """
+ tests parameters and sharding on a module level
+ """
+ device_mesh = DeviceMesh(
+ self.device_type,
+ torch.arange(0, NUM_DEVICES),
+ )
+
+ inp_size = [8, 10]
+ torch.manual_seed(0)
+ inp = torch.rand(*inp_size, device=self.device_type)
+ model = MLPModule(self.device_type)
+
+ parallelize_plan = {
+ "net1": ColwiseParallel(),
+ "net2": RowwiseParallel(),
+ }
+
+ model = parallelize_module(model, device_mesh, parallelize_plan)
+
+ comm_mode = CommDebugMode()
+
+ with comm_mode:
+ output_tp = model(inp)
+ output_tp.sum().backward()
+
+ module_parameters_dict, module_sharding_dict = self.ground_truth(model)
+
+ # checks if parameter / sharding info is the same as ground truth
+ self.check_same_set_of_keys(
+ module_parameters_dict, comm_mode.get_parameter_info()
+ )
+ self.check_same_set_of_keys(module_sharding_dict, comm_mode.get_sharding_info())
+
+ @with_comms
+ def test_MLPStacked_distributed_sharding_display(self):
+ """
+ tests model with nested modules and makes sure comm_mode correctly resets parameter and sharding information
+ """
+
+ device_mesh = DeviceMesh(
+ self.device_type,
+ torch.arange(0, NUM_DEVICES),
+ )
+
+ inp_size = [8, 10]
+ torch.manual_seed(0)
+ inp = torch.rand(*inp_size, device=self.device_type)
+ model = MLPModule(self.device_type)
+
+ parallelize_plan = {
+ "net1": ColwiseParallel(),
+ "net2": RowwiseParallel(),
+ }
+
+ model = parallelize_module(model, device_mesh, parallelize_plan)
+
+ comm_mode = CommDebugMode()
+
+ with comm_mode:
+ output_tp = model(inp)
+ output_tp.sum().backward()
+
+ model2 = MLPStacked(self.device_type)
+
+ parallelize_plan = {
+            "layers.0.net1": ColwiseParallel(),
+            "layers.0.net2": RowwiseParallel(),
+            "layers.1.net1": ColwiseParallel(),
+            "layers.1.net2": RowwiseParallel(),
+ }
+
+ model2 = parallelize_module(model2, device_mesh, parallelize_plan)
+
+ with comm_mode:
+ # ensures that comm_mode is resetting properly
+ self.assertEqual(comm_mode.get_parameter_info(), {})
+ self.assertEqual(comm_mode.get_sharding_info(), {})
+
+ output_tp = model2(inp)
+
+ module_parameters_dict, module_sharding_dict = self.ground_truth(model2)
+
+ self.check_same_set_of_keys(
+ module_parameters_dict, comm_mode.get_parameter_info()
+ )
+ self.check_same_set_of_keys(module_sharding_dict, comm_mode.get_sharding_info())
+ self.assertEqual(len(comm_mode.get_sharding_info()), 8)
+
+ @with_comms
+ def test_MLP_module_tracing(self):
+ """
+ tests module-level tracing for MLP module
+ """
+
+ device_mesh = DeviceMesh(
+ self.device_type,
+ torch.arange(0, NUM_DEVICES),
+ )
+ inp_size = [8, 10]
+ torch.manual_seed(0)
+ inp = torch.rand(*inp_size, device=self.device_type)
+ model = MLPModule(self.device_type)
+
+ parallelize_plan = {
+ "net1": ColwiseParallel(),
+ "net2": RowwiseParallel(),
+ }
+
+ model = parallelize_module(model, device_mesh, parallelize_plan)
+
+ comm_mode = CommDebugMode()
+
+ with comm_mode:
+ output_tp = model(inp)
+ output_tp.sum().backward()
+
+        # checks that all sub-modules made it into module_depth_dict
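+        # (expected entries: the Global root, MLPModule, and its net1 / relu / net2 children)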
+ self.assertEqual(len(comm_mode.advanced_module_tracker.module_depth_dict), 5)
+
+        # checks that all collectives were correctly traced at the module level
+ self.assertEqual(
+ comm_mode.comm_module_counts["Global"][c10d_functional.all_reduce], 1
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["MLPModule"][c10d_functional.all_reduce], 1
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["MLPModule.net2"][c10d_functional.all_reduce],
+ 1,
+ )
+
+ @skip_unless_torch_gpu
+ @with_comms
+ def test_transformer_module_tracing(self, is_seq_parallel=False):
+ """
+        tests module-level tracing for a more complicated transformer module and
+        ensures that comm_mode's module depth and tracing dictionaries correctly reset
+ """
+ device_mesh = DeviceMesh(
+ self.device_type,
+ torch.arange(0, NUM_DEVICES),
+ )
+ inp_size = [8, 10]
+ torch.manual_seed(0)
+ inp = torch.rand(*inp_size, device=self.device_type)
+ model = MLPModule(self.device_type)
+
+ parallelize_plan = {
+ "net1": ColwiseParallel(),
+ "net2": RowwiseParallel(),
+ }
+
+ model = parallelize_module(model, device_mesh, parallelize_plan)
+
+ comm_mode = CommDebugMode()
+ with comm_mode:
+ self.assertEqual(
+ len(comm_mode.advanced_module_tracker.module_depth_dict), 1
+ )
+ self.assertEqual(comm_mode.comm_module_counts, {})
+ output_tp = model(inp)
+
+ model_args = ModelArgs(dropout_p=0.0)
+ model2 = Transformer(model_args).to(device=self.device_type)
+ model2 = Transformer.parallelize(model2, device_mesh, is_seq_parallel)
+
+ inp_size = [8, 8]
+
+ inp = torch.randint(model_args.vocab_size, inp_size, device=self.device_type)
+ inp = distribute_tensor(inp, device_mesh=device_mesh)
+
+ comm_mode = CommDebugMode()
+ with comm_mode:
+ output = model2(inp)
+
+        # checks that all collectives were correctly traced at the module level
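+        # expected Global totals, broken down per-module below:
+        #   6 all_reduces = tok_embeddings (1) + pos_embeddings (1) + 2 per layer x 2 layers
+        #   1 all_gather_into_tensor from Transformer.output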
+ self.assertEqual(
+ comm_mode.comm_module_counts["Global"][c10d_functional.all_reduce], 6
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Global"][
+ c10d_functional.all_gather_into_tensor
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer"][c10d_functional.all_reduce], 6
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer"][
+ c10d_functional.all_gather_into_tensor
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.tok_embeddings"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.pos_embeddings"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.0"][
+ c10d_functional.all_reduce
+ ],
+ 2,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.0.attention"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.0.attention.wo"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.0.feed_forward"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.0.feed_forward.w2"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.1"][
+ c10d_functional.all_reduce
+ ],
+ 2,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.1.attention"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.1.attention.wo"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.1.feed_forward"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.layers.1.feed_forward.w2"][
+ c10d_functional.all_reduce
+ ],
+ 1,
+ )
+ self.assertEqual(
+ comm_mode.comm_module_counts["Transformer.output"][
+ c10d_functional.all_gather_into_tensor
+ ],
+ 1,
+ )
+
+
+if __name__ == "__main__":
+ run_tests()