[dtensor][debug] allow visualize_sharding to print header (#121179)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121179
Approved by: https://github.com/wanchaol
diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py
index 6bf4c04..d07a239 100644
--- a/torch/distributed/_tensor/debug/visualize_sharding.py
+++ b/torch/distributed/_tensor/debug/visualize_sharding.py
@@ -130,7 +130,7 @@
         return tuple(local_shape), tuple(global_offset)
 
 
-def visualize_sharding(dtensor):
+def visualize_sharding(dtensor, header=""):
     """
     Visualizes sharding in 1D-2D dtensors
     Requires tabulate, install with `pip install tabulate`
@@ -154,4 +154,5 @@
     # Convert offsets to blocks with row_ranges for tabulate
     blocks = _convert_offset_to_ranges(all_offsets)
     if device_mesh.get_rank() == 0:
+        print(header)
         print(_create_table(blocks))