commit | 9cc0f23e5ccc122bcdeeb608ddcbf392c4e44802 | [log] [tgz] |
---|---|---|
author | Xilun Wu <12968408+XilunWu@users.noreply.github.com> | Wed Mar 06 13:48:26 2024 -0800 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Thu Mar 07 04:50:06 2024 +0000 |
tree | 78faeaac32ca398d5624279d31ac46c9b8d75426 | |
parent | a2854ae904c55ebf19031ca6e883e5dcbb6a4a9a [diff] |
[dtensor][debug] allow visualize_sharding to print header (#121179) Pull Request resolved: https://github.com/pytorch/pytorch/pull/121179 Approved by: https://github.com/wanchaol
diff --git a/torch/distributed/_tensor/debug/visualize_sharding.py b/torch/distributed/_tensor/debug/visualize_sharding.py index 6bf4c04..d07a239 100644 --- a/torch/distributed/_tensor/debug/visualize_sharding.py +++ b/torch/distributed/_tensor/debug/visualize_sharding.py
@@ -130,7 +130,7 @@ return tuple(local_shape), tuple(global_offset) -def visualize_sharding(dtensor): +def visualize_sharding(dtensor, header=""): """ Visualizes sharding in 1D-2D dtensors Requires tabulate, install with `pip install tabulate` @@ -154,4 +154,5 @@ # Convert offsets to blocks with row_ranges for tabulate blocks = _convert_offset_to_ranges(all_offsets) if device_mesh.get_rank() == 0: + print(header) print(_create_table(blocks))