[UCC] Fix input tensor in scatter (#112246)
Input tensor is valid only for root rank. Fixes https://github.com/openucx/ucc/issues/859
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112246
Approved by: https://github.com/Aidyn-A, https://github.com/Fuzzkatt, https://github.com/kwen2501
diff --git a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
index 35a15d7..cd6c4aa 100644
--- a/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
+++ b/torch/csrc/distributed/c10d/ProcessGroupUCC.cpp
@@ -1512,7 +1512,7 @@
coll,
std::unique_ptr<WorkData>(data),
tensor.device(),
- inputTensors[0],
+ (getRank() == opts.rootRank) ? inputTensors[0] : outputTensors,
outputTensors,
"ucc:scatter");
}