commit | d707724ac9ce184af46bc4b5c1d923e15b798bdd | [log] [tgz] |
---|---|---|
author | wz337 <wz337@cornell.edu> | Thu Aug 24 01:28:19 2023 +0000 |
committer | PyTorch MergeBot <pytorchmergebot@users.noreply.github.com> | Thu Aug 24 01:28:22 2023 +0000 |
tree | a2eb72ead31e1ec0aa0e3f8e6178ccdc9575a57f | |
parent | 26ae48832e1b184c9eeabe8f6c3e36661f0bb00e [diff] |
[DeviceMesh] init_device_mesh dosctring update to include one d mesh initialization (#107805) As title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/107805 Approved by: https://github.com/fduwjj, https://github.com/wanchaol
diff --git a/torch/distributed/_tensor/device_mesh.py b/torch/distributed/_tensor/device_mesh.py index 276d79d..2ac36bb 100644 --- a/torch/distributed/_tensor/device_mesh.py +++ b/torch/distributed/_tensor/device_mesh.py
@@ -381,6 +381,7 @@ >>> # xdoctest: +SKIP >>> from torch.distributed._tensor.device_mesh import init_device_mesh >>> + >>> one_d_mesh = init_device_mesh("cuda", mesh_shape=(8,)) >>> two_d_mesh = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) """ if mesh_dim_names is not None and len(mesh_shape) != len(mesh_dim_names):