Move the allgather_coalesced implementation from Python to C++. (#29059)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29059
This is a resubmit of the reverted diff D18209289 (PR #28857).
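
After this change, torch.distributed.all_gather_coalesced forwards its argument lists directly to the new C++ ProcessGroup::allgather_coalesced binding (implemented for Gloo; the MPI and NCCL backends raise), and async_op=True now returns the work handle instead of raising. A minimal usage sketch follows; the run_worker helper, store path, world size, and tensor shapes are illustrative and not taken from this diff:

    import torch
    import torch.distributed as dist

    def run_worker(rank, world_size, store_path):
        # Gloo process group backed by a shared FileStore (CPU tensors only).
        store = dist.FileStore(store_path, world_size)
        dist.init_process_group("gloo", store=store, rank=rank, world_size=world_size)

        # Each rank contributes two tensors of different shapes.
        input_tensors = [torch.full((2,), float(rank)), torch.full((3,), float(rank))]

        # One output list per rank; shapes and dtypes must mirror the inputs.
        output_lists = [[torch.zeros(2), torch.zeros(3)] for _ in range(world_size)]

        dist.all_gather_coalesced(output_lists, input_tensors)
        # After the call, output_lists[r] holds rank r's input tensors on every rank.

With async_op=True the same call returns a work object; calling wait() on it completes the copy into output_lists.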

Test Plan:
buck test caffe2/test:c10d
buck test caffe2/test:distributed_gloo

Reviewed By: pietern

Differential Revision: D18277097

fbshipit-source-id: aecfd7206d70829f0cac66182bf02fccee410fed
diff --git a/test/test_c10d.py b/test/test_c10d.py
index 7d8d85b..d4b6ac6 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -1309,31 +1309,37 @@
     def test_allgather_coalesced_checks(self):
         store = c10d.FileStore(self.file_name, self.world_size)
         pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
-        dummy_input = [torch.Tensor([1])]
+        dummy_input = [torch.zeros([1], dtype=torch.float32)]
         dummy_output_lists = [
-            [torch.Tensor([-1])] for _ in range(self.world_size)
+            [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
         ]
-        with self.assertRaisesRegex(RuntimeError,
-                                    "all_gather_coalesced does not support "
-                                    "async mode yet."):
-            c10d.all_gather_coalesced(
-                dummy_output_lists, dummy_input, pg, async_op=True)
 
         # One of output tensors does not match input list.
-        dummy_output_lists[0] = [torch.Tensor(0)]
-        with self.assertRaisesRegex(RuntimeError, "Shape tensor mismatch"):
+        dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)]
+        with self.assertRaisesRegex(ValueError,
+                                    "invalid size of output tensor at index 0"):
             c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
 
-        # Output is not a list of lists.
-        dummy_output_lists = [torch.Tensor(0)]
-        with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
+        # One of the output tensors does not match the input list dtype.
+        dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)]
+        with self.assertRaisesRegex(ValueError,
+                                    "invalid tensor type at index 0"):
             c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
 
         # Output lists have too many elements
         dummy_output_lists = [
-            [torch.Tensor([-1])] for _ in range(self.world_size + 1)
+            [
+                torch.zeros([1], dtype=torch.float32)
+            ] for _ in range(self.world_size + 1)
         ]
-        with self.assertRaisesRegex(ValueError, "invalid output tensor"):
+        with self.assertRaisesRegex(ValueError,
+                                    "output lists should be equal to world size"):
+            c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
+
+        # Output is not a list of lists.
+        dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
+        with self.assertRaisesRegex(RuntimeError,
+                                    "Invalid function argument.*output_tensor_lists"):
             c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
 
 
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 29c7035..18bff0f 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -1426,27 +1426,33 @@
                     ] for rank_iter in group
                 ]
                 assert self._run_all_gather_coalesced_and_verify(
-                    output_tensor_lists, input_tensors, expected_tensors, group_id)
+                    output_tensor_lists, input_tensors,
+                    expected_tensors, group_id
+                ), "output tensors do not match expected ouputs"
 
         self._barrier()
 
-    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
-    def test_all_gather_coalesced(self):
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
+    def test_all_gather_coalesced_simple(self):
         group, group_id, rank = self._init_global_test()
         self._test_all_gather_coalesced_helper(group, group_id, rank)
 
     @skip_if_small_worldsize
-    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
     def test_all_gather_coalesced_group(self):
         group, group_id, rank = self._init_group_test()
         self._test_all_gather_coalesced_helper(group, group_id, rank)
 
-    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
     def test_all_gather_coalesced_full_group(self):
         group, group_id, rank = self._init_full_group_test()
         self._test_all_gather_coalesced_helper(group, group_id, rank)
 
-    @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+    @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+    @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
     def test_all_gather_coalesced_with_empty(self):
         group, group_id, rank = self._init_global_test()
         input_tensors = [
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index c9887df..e9c6657 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -318,6 +318,14 @@
               py::call_guard<py::gil_scoped_release>())
 
           .def(
+              "allgather_coalesced",
+              &::c10d::ProcessGroup::allgather_coalesced,
+              py::arg("output_lists"),
+              py::arg("input_list"),
+              py::arg("opts") = ::c10d::AllgatherOptions(),
+              py::call_guard<py::gil_scoped_release>())
+
+          .def(
               "gather",
               &::c10d::ProcessGroup::gather,
               py::arg("output_tensors"),
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index e60120c..bf43726 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1157,32 +1157,6 @@
     else:
         work.wait()
 
-def _check_lists_of_same_shape(list1, list2):
-    """
-    Helper to check that individual tensors of 'list1' and 'list2' match in
-    shape point-wise.
-
-    Returns:
-        True if list1 matches list2
-    """
-    return (len(list1) == len(list2) and
-            all(t1.size() == t2.size() for t1, t2 in zip(list1, list2)))
-
-
-def _check_input_output(output_tensor_lists, input_tensor_list, param_name):
-    """
-    Helper to check that each element of output_tensor_lists matches input_tensor_list in shape.
-
-    """
-    if not isinstance(output_tensor_lists, list):
-        raise RuntimeError("Invalid function argument. Expected parameter `{}` "
-                           "to be of type List[List[torch.Tensor]].".format(param_name))
-    all(_check_tensor_list(
-        output_tensor_list, param_name + " element ") for output_tensor_list in output_tensor_lists)
-    if not all(_check_lists_of_same_shape(
-            output_tensor_list, input_tensor_list) for output_tensor_list in output_tensor_lists):
-        raise RuntimeError("Shape tensor mismatch in {}".format(param_name))
-
 def all_gather_coalesced(output_tensor_lists,
                          input_tensor_list,
                          group=group.WORLD,
@@ -1197,7 +1171,6 @@
             current process. At least one tensor has to be non empty.
         group (ProcessGroup, optional): The process group to work on
         async_op (bool, optional): Whether this op should be an async op.
-            currently does not support async_op being True
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -1228,41 +1201,28 @@
     performance improvements but users of this function should take extra care
     to ensure that each node passes in tensors whose shapes match across nodes.
     """
-    _check_tensor_list(input_tensor_list, "tensor_list")
-    _check_input_output(
-        output_tensor_lists, input_tensor_list, "output_tensor_lists")
+    # We only check basic compatibility with C++ params here; the C++ code
+    # will do shape and type checking.
     if _rank_not_in_group(group):
         return
-
-    # Flatten the input and create a list of flat outputs.
-    input_coalesced = torch.cat([t.flatten() for t in input_tensor_list])
-    output_coalesced = [
-        torch.empty(input_coalesced.numel()) for _ in output_tensor_lists]
+    _check_tensor_list(input_tensor_list, "tensor_list")
+    if not isinstance(output_tensor_lists, list):
+        RuntimeError("Invalid function argument: "
+                     "output_tensor_lists should be a list")
+    for output_tensor_list in output_tensor_lists:
+        _check_tensor_list(output_tensor_list, "output_tensor_lists")
 
     if group == GroupMember.WORLD:
         _check_default_pg()
-        work = _default_pg.allgather([output_coalesced], [input_coalesced])
+        work = _default_pg.allgather_coalesced(
+            output_tensor_lists, input_tensor_list)
     else:
-        work = group.allgather([output_coalesced], [input_coalesced])
+        work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
 
     if async_op:
-        raise RuntimeError("all_gather_coalesced does not support "
-                           "async mode yet.")
+        return work
     else:
         work.wait()
-        assert len(output_tensor_lists) == len(output_coalesced)
-        # Iterate through flat outputs and store them in output_tensor_lists
-        for rank, rank_output in enumerate(output_coalesced):
-            current_element = 0
-            for index_output, output_tensor in enumerate(
-                    output_tensor_lists[rank]):
-                output_tensor.copy_(
-                    rank_output.narrow(
-                        dim=0,
-                        start=current_element,
-                        length=output_tensor.numel()).reshape(
-                            output_tensor.size()))
-                current_element += output_tensor.numel()
 
 
 def gather(tensor,
diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp
index 8ca8d64..6718385 100644
--- a/torch/lib/c10d/ProcessGroup.hpp
+++ b/torch/lib/c10d/ProcessGroup.hpp
@@ -124,6 +124,11 @@
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) = 0;
 
+  virtual std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+      std::vector<std::vector<at::Tensor>>& outputTensorLists,
+      std::vector<at::Tensor>& inputTensors,
+      const AllgatherOptions& opts = AllgatherOptions()) = 0;
+
   virtual std::shared_ptr<ProcessGroup::Work> gather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 9a3b763..d0f3c48 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -1745,6 +1745,126 @@
 
 namespace {
 
+class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
+ public:
+  AsyncAllgatherCoalescedWork(
+      const std::shared_ptr<gloo::Context>& context,
+      std::vector<std::vector<at::Tensor>>& output_lists,
+      std::vector<at::Tensor>& input_list,
+      uint32_t tag)
+      : context(context),
+        output_lists(output_lists),
+        input_list(input_list),
+        tag(tag) {}
+
+  std::shared_ptr<gloo::Context> context;
+  std::vector<std::vector<at::Tensor>> output_lists;
+  std::vector<at::Tensor> input_list;
+  const uint32_t tag;
+
+  void allgather_coalesced() {
+    assert(!output_lists.empty());
+    assert(!output_lists[0].empty());
+    assert(!input_list.empty());
+
+    const auto& scalarType = input_list[0].scalar_type();
+    gloo::AllgatherOptions opts(context);
+    opts.setTag(tag);
+
+    // Use single flattened input tensor.
+    at::Tensor flatInputTensor = flattenDenseTensors(input_list);
+    GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
+
+    // Compute total number of elements we need to allocate for all tensors
+    // requested.
+    int64_t output_numel = 0;
+    for (const auto& t : output_lists[0]) {
+      output_numel += t.numel();
+    }
+    output_numel *= output_lists.size();
+    // Use single flat output tensor.
+    at::Tensor flatOutputTensor =
+        at::empty({output_numel}, output_lists[0][0].options());
+    GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
+    gloo::allgather(opts);
+
+    int64_t current_element = 0;
+    for (auto& output_list : output_lists) {
+      for (auto& output_tensor : output_list) {
+        output_tensor.copy_(
+            flatOutputTensor.narrow(0, current_element, output_tensor.numel())
+                .reshape(output_tensor.sizes()),
+            true);
+        current_element += output_tensor.numel();
+      }
+    }
+  }
+
+  void run() override {
+    allgather_coalesced();
+  }
+};
+
+} // namespace
+
+std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather_coalesced(
+    std::vector<std::vector<at::Tensor>>& output_lists,
+    std::vector<at::Tensor>& input_list,
+    const AllgatherOptions& /* unused */) {
+  static auto invalidArgument = [](const std::string& msg) {
+    throw std::invalid_argument(
+        "ProcessGroupGloo::allgather_coalesced: " + msg);
+  };
+
+  if (input_list.empty()) {
+    invalidArgument("requires non-empty input tensor list");
+  }
+
+  if (output_lists.size() != getSize()) {
+    invalidArgument("output lists should be equal to world size");
+  }
+
+  assertSameDevice(invalidArgument, input_list);
+
+  // Expect the i'th tensor of each list in 'output_lists' to match the i'th
+  // tensor from 'input_list' in type and size.
+  for (const auto& output_list : output_lists) {
+    if (output_list.size() != input_list.size()) {
+      invalidArgument(
+          "invalid output size: (expected length " +
+          std::to_string(input_list.size()) + ", got " +
+          std::to_string(output_list.size()) + ")");
+    }
+    for (int i = 0; i < output_list.size(); ++i) {
+      const auto expected = input_list[i].sizes();
+      const auto actual = output_list[i].sizes();
+      if (actual != expected) {
+        invalidArgument(
+            "invalid size of output tensor at index " + std::to_string(i) +
+            " (expected length " + toString(expected) + ", got " +
+            toString(actual) + ")");
+      }
+      if (input_list[i].type() != output_list[i].type()) {
+        invalidArgument(
+            "invalid tensor type at index " + std::to_string(i) +
+            " (expected " + input_list[i].type().toString() + ", got " +
+            output_list[i].type().toString() + ")");
+      }
+    }
+  }
+
+  assertDense(invalidArgument, input_list);
+
+  auto tag = nextTag();
+  auto context = getContext(tag);
+  auto work = std::make_shared<AsyncAllgatherCoalescedWork>(
+      std::move(context), output_lists, input_list, tag);
+  enqueue(work);
+  return work;
+}
+
+namespace {
+
 class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
  public:
   AsyncGatherWork(
diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp
index 058bf90..ead2ed5 100644
--- a/torch/lib/c10d/ProcessGroupGloo.hpp
+++ b/torch/lib/c10d/ProcessGroupGloo.hpp
@@ -175,6 +175,11 @@
       std::vector<at::Tensor>& inputs,
       const AllgatherOptions& opts = AllgatherOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+      std::vector<std::vector<at::Tensor>>& output_lists,
+      std::vector<at::Tensor>& input_list,
+      const AllgatherOptions& opts = AllgatherOptions()) override;
+
   std::shared_ptr<ProcessGroup::Work> gather(
       std::vector<std::vector<at::Tensor>>& outputs,
       std::vector<at::Tensor>& inputs,
diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp
index 237f436..80ea46a 100644
--- a/torch/lib/c10d/ProcessGroupMPI.cpp
+++ b/torch/lib/c10d/ProcessGroupMPI.cpp
@@ -434,6 +434,14 @@
   return enqueue(std::move(entry));
 }
 
+std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather_coalesced(
+    std::vector<std::vector<at::Tensor>>& /* unused */,
+    std::vector<at::Tensor>& /* unused */,
+    const AllgatherOptions& /* unused */) {
+  throw std::runtime_error(
+      "ProcessGroupMPI does not support allgather_coalesced");
+}
+
 std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
     std::vector<std::vector<at::Tensor>>& outputTensors,
     std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp
index 68db535..a3f6986 100644
--- a/torch/lib/c10d/ProcessGroupMPI.hpp
+++ b/torch/lib/c10d/ProcessGroupMPI.hpp
@@ -128,6 +128,11 @@
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+      std::vector<std::vector<at::Tensor>>& outputTensorLists,
+      std::vector<at::Tensor>& inputTensors,
+      const AllgatherOptions& opts = AllgatherOptions()) override;
+
   std::shared_ptr<ProcessGroup::Work> gather(
       std::vector<std::vector<at::Tensor>>& outputTensors,
       std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index f26b42b..e60786c 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -705,6 +705,14 @@
       });
 }
 
+std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced(
+    std::vector<std::vector<at::Tensor>>& /* unused */,
+    std::vector<at::Tensor>& /* unused */,
+    const AllgatherOptions& /* unused */) {
+  throw std::runtime_error(
+      "ProcessGroupNCCL does not support allgather_coalesced");
+}
+
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
     std::vector<at::Tensor>& outputTensors,
     std::vector<std::vector<at::Tensor>>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp
index 64bdea4..dab1910 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.hpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -177,6 +177,11 @@
       std::vector<at::Tensor>& inputTensors,
       const AllgatherOptions& opts = AllgatherOptions()) override;
 
+  std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+      std::vector<std::vector<at::Tensor>>& outputTensorLists,
+      std::vector<at::Tensor>& inputTensors,
+      const AllgatherOptions& opts = AllgatherOptions()) override;
+
   std::shared_ptr<ProcessGroup::Work> reduce_scatter(
       std::vector<at::Tensor>& outputTensors,
       std::vector<std::vector<at::Tensor>>& inputTensors,
diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp
index 374117d..4a81c00 100644
--- a/torch/lib/c10d/Utils.hpp
+++ b/torch/lib/c10d/Utils.hpp
@@ -203,6 +203,20 @@
   }
 }
 
+inline void assertSameDevice(
+    std::function<void(const std::string&)> fn,
+    const at::ArrayRef<at::Tensor>& tensors) {
+  if (tensors.size() < 2) {
+    return;
+  }
+  const auto& device = tensors[0].device();
+  for (int i = 1; i < tensors.size(); ++i) {
+    if (tensors[i].device() != device) {
+      fn("tensors should be on the same device");
+    }
+  }
+}
+
 inline void assertTypeAndSizesMatch(
     std::function<void(const std::string&)> fn,
     const at::ArrayRef<at::Tensor>& tensors,