Enable TensorPipe CUDA fallback channel (#50675)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50675
Test Plan: Imported from OSS
Reviewed By: beauby
Differential Revision: D25941963
Pulled By: mrshenli
fbshipit-source-id: 205786d7366f36d659a3a3374081a458cfcb4dd1
diff --git a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
index 09ed5c5..4f56c91 100644
--- a/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
+++ b/torch/csrc/distributed/rpc/tensorpipe_agent.cpp
@@ -89,6 +89,7 @@
#ifdef USE_CUDA_NOT_ROCM
constexpr int64_t kCudaXthChannelPriority = 400;
+constexpr int64_t kCudaBasicChannelPriority = 100;
#endif
std::unique_ptr<TransportRegistration> makeUvTransport() {
@@ -222,6 +223,20 @@
cuda_xth,
makeCudaXthChannel);
+std::unique_ptr<CudaChannelRegistration> makeCudaBasicChannel() {
+ auto context = std::make_shared<tensorpipe::channel::cuda_basic::Context>(
+ std::make_shared<tensorpipe::channel::basic::Context>());
+ return std::make_unique<CudaChannelRegistration>(
+ CudaChannelRegistration{std::move(context), kCudaBasicChannelPriority});
+}
+
+// The cuda_basic channel is the fallback for GPU-to-GPU communication
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_REGISTER_CREATOR(
+ TensorPipeCudaChannelRegistry,
+ cuda_basic,
+ makeCudaBasicChannel);
+
#endif
} // namespace