ProcessGroupWrapper: support custom backends (#124447)
In the current code, ProcessGroupWrapper works only for the `GLOO`, `NCCL`, and `UCC` backends when `TORCH_DISTRIBUTED_DEBUG=DETAIL`.
Reading the ProcessGroupWrapper code, I found that each communication op in ProcessGroupWrapper is just the original backend's communication op preceded by `runCollectiveChecks`, which runs on a helper Gloo group. For example, `allreduce`:
https://github.com/pytorch/pytorch/blob/82e0153487c2cd1abc92598963be5b57ab1948d4/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp#L406-L411
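For readability, here is a minimal Python sketch of that pattern. The real implementation is C++ in ProcessGroupWrapper.cpp; the class and attribute names below only approximate it:

```python
# Minimal Python sketch of the wrapper pattern linked above (the real
# implementation is C++; names are approximate, not PyTorch API).
class ProcessGroupWrapperSketch:
    def __init__(self, backend, gloo_group):
        self._backend = backend        # the wrapped backend (NCCL, UCC, custom, ...)
        self._gloo_group = gloo_group  # helper Gloo group used only for checks

    def allreduce(self, tensors, opts):
        # Debug-check the collective first, then delegate to the real backend.
        self._run_collective_checks("ALLREDUCE", tensors)
        return self._backend.allreduce(tensors, opts)

    def _run_collective_checks(self, op_type, tensors):
        # Fingerprint + monitored barrier; sketched in the next snippet.
        raise NotImplementedError
```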
`runCollectiveChecks` computes a collective fingerprint of the tensors and runs Gloo's `monitoredBarrier`:
https://github.com/pytorch/pytorch/blob/82e0153487c2cd1abc92598963be5b57ab1948d4/torch/csrc/distributed/c10d/ProcessGroupWrapper.cpp#L586-L590
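A hedged sketch of what that check amounts to, continuing the Python illustration above (the `_verify_across_ranks` helper is hypothetical, standing in for the `CollectiveFingerPrint` verification the C++ code performs over the helper Gloo group):

```python
# Continues ProcessGroupWrapperSketch above; illustrative only.
def _run_collective_checks(self, op_type, tensors):
    # Fingerprint the collective: op type plus each tensor's dtype,
    # shape, and device type.
    fingerprint = [(op_type, t.dtype, tuple(t.shape), t.device.type)
                   for t in tensors]
    # Monitored barrier: a hung or crashed rank produces a clear error
    # naming the missing ranks instead of a silent deadlock.
    self._gloo_group.monitored_barrier()
    # Then verify every rank built the same fingerprint (hypothetical
    # helper; the real C++ exchanges fingerprints via the Gloo group).
    self._verify_across_ranks(fingerprint)
```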
I don't know why ProcessGroupWrapper doesn't work for all backends, but I think custom backends can support it.
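With this change, a backend registered through the plugin mechanism gets the same wrapping. A usage sketch (`my_backend` and `create_my_backend` are hypothetical placeholders for a real third-party plugin):

```python
import os
import torch.distributed as dist

# DETAIL is the debug level that triggers ProcessGroupWrapper.
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"

# Hypothetical plugin: a real plugin supplies its own construction
# function. register_backend records the name (uppercased) in
# Backend._plugins, which is what the condition in the diff below checks.
dist.Backend.register_backend("my_backend", create_my_backend)

# After this PR, the custom backend is wrapped in ProcessGroupWrapper
# at DETAIL debug level, just like GLOO/NCCL/UCC.
dist.init_process_group(backend="my_backend", rank=0, world_size=2)
```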
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124447
Approved by: https://github.com/kwen2501
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index 74f2ed5..c006fbc 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1602,7 +1602,7 @@
break
# Process group wrapper initialization for supported PGs when TORCH_DISTRIBUTED_DEBUG is set
- if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC]:
+ if backend_str in [Backend.GLOO, Backend.NCCL, Backend.UCC] or backend_str.upper() in Backend._plugins:
# In debug mode and if GLOO is available, wrap in a wrapper PG that
# enables enhanced collective checking for debuggability.
if get_debug_level() == DebugLevel.DETAIL: