Move the allgather_coalesced implementation from Python to C++. (#29059)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29059
This is a resubmit of the reverted diff D18209289 (PR #28857).
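For context, a minimal usage sketch of the public API that now routes to the new C++ `allgather_coalesced` binding. The world size, tensor shapes, and gloo initialization are illustrative, not taken verbatim from the tests:

    import torch
    import torch.distributed as dist

    # Assumes the default process group was initialized with the gloo backend,
    # e.g. dist.init_process_group("gloo", rank=rank, world_size=world_size).
    world_size = dist.get_world_size()

    # Each rank contributes a list of tensors; shapes may differ between tensors
    # in the list, but must match position-wise across ranks.
    input_tensor_list = [torch.ones(2), torch.ones(3, 3)]

    # One output list per rank, mirroring input_tensor_list in shape and dtype.
    output_tensor_lists = [
        [torch.zeros(2), torch.zeros(3, 3)] for _ in range(world_size)
    ]

    dist.all_gather_coalesced(output_tensor_lists, input_tensor_list)
    # output_tensor_lists[r] now holds a copy of rank r's input tensors.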
Test Plan:
buck test caffe2/test:c10d
buck test caffe2/test:distributed_gloo
Reviewed By: pietern
Differential Revision: D18277097
fbshipit-source-id: aecfd7206d70829f0cac66182bf02fccee410fed
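The Gloo path added below (AsyncAllgatherCoalescedWork in ProcessGroupGloo.cpp) performs one allgather over a flattened buffer and then scatters the result back into the caller's output tensors. A rough Python-level sketch of that data movement, for orientation only; allgather_flat is a stand-in for the backend collective, not a real API:

    import torch

    def allgather_coalesced_sketch(output_lists, input_list, allgather_flat):
        # Flatten every input tensor into one contiguous buffer.
        flat_input = torch.cat([t.reshape(-1) for t in input_list])

        # One flat allgather over the whole buffer; the result has
        # world_size * flat_input.numel() elements.
        flat_output = allgather_flat(flat_input)

        # Copy contiguous slices back into the caller-provided output tensors,
        # rank by rank, tensor by tensor.
        offset = 0
        for output_list in output_lists:
            for out in output_list:
                out.copy_(
                    flat_output.narrow(0, offset, out.numel()).reshape(out.shape))
                offset += out.numel()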
diff --git a/test/test_c10d.py b/test/test_c10d.py
index 7d8d85b..d4b6ac6 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -1309,31 +1309,37 @@
def test_allgather_coalesced_checks(self):
store = c10d.FileStore(self.file_name, self.world_size)
pg = c10d.ProcessGroupGloo(store, self.rank, self.world_size, self.opts())
- dummy_input = [torch.Tensor([1])]
+ dummy_input = [torch.zeros([1], dtype=torch.float32)]
dummy_output_lists = [
- [torch.Tensor([-1])] for _ in range(self.world_size)
+ [torch.zeros([1], dtype=torch.float32)] for _ in range(self.world_size)
]
- with self.assertRaisesRegex(RuntimeError,
- "all_gather_coalesced does not support "
- "async mode yet."):
- c10d.all_gather_coalesced(
- dummy_output_lists, dummy_input, pg, async_op=True)
# One of output tensors does not match input list.
- dummy_output_lists[0] = [torch.Tensor(0)]
- with self.assertRaisesRegex(RuntimeError, "Shape tensor mismatch"):
+ dummy_output_lists[0] = [torch.zeros([0], dtype=torch.float32)]
+ with self.assertRaisesRegex(ValueError,
+ "invalid size of output tensor at index 0"):
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
- # Output is not a list of lists.
- dummy_output_lists = [torch.Tensor(0)]
- with self.assertRaisesRegex(RuntimeError, "Invalid function argument"):
+ # One of the output tensors has a dtype that does not match the input list.
+ dummy_output_lists[0] = [torch.zeros([1], dtype=torch.float64)]
+ with self.assertRaisesRegex(ValueError,
+ "invalid tensor type at index 0"):
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
# Output lists have too many elements
dummy_output_lists = [
- [torch.Tensor([-1])] for _ in range(self.world_size + 1)
+ [
+ torch.zeros([1], dtype=torch.float32)
+ ] for _ in range(self.world_size + 1)
]
- with self.assertRaisesRegex(ValueError, "invalid output tensor"):
+ with self.assertRaisesRegex(ValueError,
+ "output lists should be equal to world size"):
+ c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
+
+ # Output is not a list of lists.
+ dummy_output_lists = [torch.zeros([0], dtype=torch.float32)]
+ with self.assertRaisesRegex(RuntimeError,
+ "Invalid function argument.*output_tensor_lists"):
c10d.all_gather_coalesced(dummy_output_lists, dummy_input, pg)
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 29c7035..18bff0f 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -1426,27 +1426,33 @@
] for rank_iter in group
]
assert self._run_all_gather_coalesced_and_verify(
- output_tensor_lists, input_tensors, expected_tensors, group_id)
+ output_tensor_lists, input_tensors,
+ expected_tensors, group_id
+ ), "output tensors do not match expected ouputs"
self._barrier()
- @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
- def test_all_gather_coalesced(self):
+ @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+ @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
+ def test_all_gather_coalesced_simple(self):
group, group_id, rank = self._init_global_test()
self._test_all_gather_coalesced_helper(group, group_id, rank)
@skip_if_small_worldsize
- @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+ @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+ @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
def test_all_gather_coalesced_group(self):
group, group_id, rank = self._init_group_test()
self._test_all_gather_coalesced_helper(group, group_id, rank)
- @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+ @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+ @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
def test_all_gather_coalesced_full_group(self):
group, group_id, rank = self._init_full_group_test()
self._test_all_gather_coalesced_helper(group, group_id, rank)
- @unittest.skipIf(BACKEND == "nccl", "Nccl does not support CPU tensors")
+ @unittest.skipIf(BACKEND == "nccl", "all_gather_coalesced does not support NCCL")
+ @unittest.skipIf(BACKEND == "mpi", "all_gather_coalesced does not support MPI")
def test_all_gather_coalesced_with_empty(self):
group, group_id, rank = self._init_global_test()
input_tensors = [
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index c9887df..e9c6657 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -318,6 +318,14 @@
py::call_guard<py::gil_scoped_release>())
.def(
+ "allgather_coalesced",
+ &::c10d::ProcessGroup::allgather_coalesced,
+ py::arg("output_lists"),
+ py::arg("input_list"),
+ py::arg("opts") = ::c10d::AllgatherOptions(),
+ py::call_guard<py::gil_scoped_release>())
+
+ .def(
"gather",
&::c10d::ProcessGroup::gather,
py::arg("output_tensors"),
diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py
index e60120c..bf43726 100644
--- a/torch/distributed/distributed_c10d.py
+++ b/torch/distributed/distributed_c10d.py
@@ -1157,32 +1157,6 @@
else:
work.wait()
-def _check_lists_of_same_shape(list1, list2):
- """
- Helper to check that individual tensors of 'list1' and 'list2' match in
- shape point-wise.
-
- Returns:
- True if list1 matches list2
- """
- return (len(list1) == len(list2) and
- all(t1.size() == t2.size() for t1, t2 in zip(list1, list2)))
-
-
-def _check_input_output(output_tensor_lists, input_tensor_list, param_name):
- """
- Helper to check that each element of output_tensor_lists matches input_tensor_list in shape.
-
- """
- if not isinstance(output_tensor_lists, list):
- raise RuntimeError("Invalid function argument. Expected parameter `{}` "
- "to be of type List[List[torch.Tensor]].".format(param_name))
- all(_check_tensor_list(
- output_tensor_list, param_name + " element ") for output_tensor_list in output_tensor_lists)
- if not all(_check_lists_of_same_shape(
- output_tensor_list, input_tensor_list) for output_tensor_list in output_tensor_lists):
- raise RuntimeError("Shape tensor mismatch in {}".format(param_name))
-
def all_gather_coalesced(output_tensor_lists,
input_tensor_list,
group=group.WORLD,
@@ -1197,7 +1171,6 @@
current process. At least one tensor has to be non empty.
group (ProcessGroup, optional): The process group to work on
async_op (bool, optional): Whether this op should be an async op.
- currently does not support async_op being True
Returns:
Async work handle, if async_op is set to True.
@@ -1228,41 +1201,28 @@
performance improvements but users of this function should take extra care
to ensure that each node passes in tensors whose shapes match across nodes.
"""
- _check_tensor_list(input_tensor_list, "tensor_list")
- _check_input_output(
- output_tensor_lists, input_tensor_list, "output_tensor_lists")
+ # We only do basic argument checks here; the C++ code performs the full
+ # shape and type checking.
if _rank_not_in_group(group):
return
-
- # Flatten the input and create a list of flat outputs.
- input_coalesced = torch.cat([t.flatten() for t in input_tensor_list])
- output_coalesced = [
- torch.empty(input_coalesced.numel()) for _ in output_tensor_lists]
+ _check_tensor_list(input_tensor_list, "tensor_list")
+ if not isinstance(output_tensor_lists, list):
+ raise RuntimeError("Invalid function argument: "
+ "output_tensor_lists should be a list")
+ for output_tensor_list in output_tensor_lists:
+ _check_tensor_list(output_tensor_list, "output_tensor_lists")
if group == GroupMember.WORLD:
_check_default_pg()
- work = _default_pg.allgather([output_coalesced], [input_coalesced])
+ work = _default_pg.allgather_coalesced(
+ output_tensor_lists, input_tensor_list)
else:
- work = group.allgather([output_coalesced], [input_coalesced])
+ work = group.allgather_coalesced(output_tensor_lists, input_tensor_list)
if async_op:
- raise RuntimeError("all_gather_coalesced does not support "
- "async mode yet.")
+ return work
else:
work.wait()
- assert len(output_tensor_lists) == len(output_coalesced)
- # Iterate through flat outputs and store them in output_tensor_lists
- for rank, rank_output in enumerate(output_coalesced):
- current_element = 0
- for index_output, output_tensor in enumerate(
- output_tensor_lists[rank]):
- output_tensor.copy_(
- rank_output.narrow(
- dim=0,
- start=current_element,
- length=output_tensor.numel()).reshape(
- output_tensor.size()))
- current_element += output_tensor.numel()
def gather(tensor,
diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp
index 8ca8d64..6718385 100644
--- a/torch/lib/c10d/ProcessGroup.hpp
+++ b/torch/lib/c10d/ProcessGroup.hpp
@@ -124,6 +124,11 @@
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) = 0;
+ virtual std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& outputTensorLists,
+ std::vector<at::Tensor>& inputTensors,
+ const AllgatherOptions& opts = AllgatherOptions()) = 0;
+
virtual std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 9a3b763..d0f3c48 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -1745,6 +1745,126 @@
namespace {
+class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
+ public:
+ AsyncAllgatherCoalescedWork(
+ const std::shared_ptr<gloo::Context>& context,
+ std::vector<std::vector<at::Tensor>>& output_lists,
+ std::vector<at::Tensor>& input_list,
+ uint32_t tag)
+ : context(context),
+ output_lists(output_lists),
+ input_list(input_list),
+ tag(tag) {}
+
+ std::shared_ptr<gloo::Context> context;
+ std::vector<std::vector<at::Tensor>> output_lists;
+ std::vector<at::Tensor> input_list;
+ const uint32_t tag;
+
+ void allgather_coalesced() {
+ assert(!output_lists.empty());
+ assert(!output_lists[0].empty());
+ assert(!input_list.empty());
+
+ const auto& scalarType = input_list[0].scalar_type();
+ gloo::AllgatherOptions opts(context);
+ opts.setTag(tag);
+
+ // Use single flattened input tensor.
+ at::Tensor flatInputTensor = flattenDenseTensors(input_list);
+ GENERATE_ALL_TYPES(scalarType, setInput, opts, flatInputTensor);
+
+ // Compute total number of elements we need to allocate for all tensors
+ // requested.
+ int64_t output_numel = 0;
+ for (const auto& t : output_lists[0]) {
+ output_numel += t.numel();
+ }
+ output_numel *= output_lists.size();
+ // Use single flat output tensor.
+ at::Tensor flatOutputTensor =
+ at::empty({output_numel}, output_lists[0][0].options());
+ GENERATE_ALL_TYPES(scalarType, setOutput, opts, flatOutputTensor);
+ gloo::allgather(opts);
+
+ int64_t current_element = 0;
+ for (auto& output_list : output_lists) {
+ for (auto& output_tensor : output_list) {
+ output_tensor.copy_(
+ flatOutputTensor.narrow(0, current_element, output_tensor.numel())
+ .reshape(output_tensor.sizes()),
+ true);
+ current_element += output_tensor.numel();
+ }
+ }
+ }
+
+ void run() override {
+ allgather_coalesced();
+ }
+};
+
+} // namespace
+
+std::shared_ptr<ProcessGroup::Work> ProcessGroupGloo::allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& output_lists,
+ std::vector<at::Tensor>& input_list,
+ const AllgatherOptions& /* unused */) {
+ static auto invalidArgument = [](const std::string& msg) {
+ throw std::invalid_argument(
+ "ProcessGroupGloo::allgather_coalesced: " + msg);
+ };
+
+ if (input_list.empty()) {
+ invalidArgument("requires non-empty input tensor list");
+ }
+
+ if (output_lists.size() != getSize()) {
+ invalidArgument("output lists should be equal to world size");
+ }
+
+ assertSameDevice(invalidArgument, input_list);
+
+ // Expect the i'th tensor of each list in 'output_lists' to match the
+ // i'th tensor in 'input_list' in type and size.
+ for (const auto& output_list : output_lists) {
+ if (output_list.size() != input_list.size()) {
+ invalidArgument(
+ "invalid output size: (expected length " +
+ std::to_string(input_list.size()) + ", got " +
+ std::to_string(output_list.size()) + ")");
+ }
+ for (int i = 0; i < output_list.size(); ++i) {
+ const auto expected = input_list[i].sizes();
+ const auto actual = output_list[i].sizes();
+ if (actual != expected) {
+ invalidArgument(
+ "invalid size of output tensor at index " + std::to_string(i) +
+ " (expected length " + toString(expected) + ", got " +
+ toString(actual) + ")");
+ }
+ if (input_list[i].type() != output_list[i].type()) {
+ invalidArgument(
+ "invalid tensor type at index " + std::to_string(i) +
+ " (expected " + input_list[i].type().toString() + ", got " +
+ output_list[i].type().toString() + ")");
+ }
+ }
+ }
+
+ assertDense(invalidArgument, input_list);
+
+ auto tag = nextTag();
+ auto context = getContext(tag);
+ auto work = std::make_shared<AsyncAllgatherCoalescedWork>(
+ std::move(context), output_lists, input_list, tag);
+ enqueue(work);
+ return work;
+}
+
+namespace {
+
class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
public:
AsyncGatherWork(
diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp
index 058bf90..ead2ed5 100644
--- a/torch/lib/c10d/ProcessGroupGloo.hpp
+++ b/torch/lib/c10d/ProcessGroupGloo.hpp
@@ -175,6 +175,11 @@
std::vector<at::Tensor>& inputs,
const AllgatherOptions& opts = AllgatherOptions()) override;
+ std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& output_lists,
+ std::vector<at::Tensor>& input_list,
+ const AllgatherOptions& opts = AllgatherOptions()) override;
+
std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputs,
std::vector<at::Tensor>& inputs,
diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp
index 237f436..80ea46a 100644
--- a/torch/lib/c10d/ProcessGroupMPI.cpp
+++ b/torch/lib/c10d/ProcessGroupMPI.cpp
@@ -434,6 +434,14 @@
return enqueue(std::move(entry));
}
+std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& /* unused */,
+ std::vector<at::Tensor>& /* unused */,
+ const AllgatherOptions& /* unused */) {
+ throw std::runtime_error(
+ "ProcessGroupMPI does not support allgather_coalesced");
+}
+
std::shared_ptr<ProcessGroup::Work> ProcessGroupMPI::gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp
index 68db535..a3f6986 100644
--- a/torch/lib/c10d/ProcessGroupMPI.hpp
+++ b/torch/lib/c10d/ProcessGroupMPI.hpp
@@ -128,6 +128,11 @@
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
+ std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& outputTensorLists,
+ std::vector<at::Tensor>& inputTensors,
+ const AllgatherOptions& opts = AllgatherOptions()) override;
+
std::shared_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp
index f26b42b..e60786c 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.cpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.cpp
@@ -705,6 +705,14 @@
});
}
+std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& /* unused */,
+ std::vector<at::Tensor>& /* unused */,
+ const AllgatherOptions& /* unused */) {
+ throw std::runtime_error(
+ "ProcessGroupNCCL does not support allgather_coalesced");
+}
+
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp
index 64bdea4..dab1910 100644
--- a/torch/lib/c10d/ProcessGroupNCCL.hpp
+++ b/torch/lib/c10d/ProcessGroupNCCL.hpp
@@ -177,6 +177,11 @@
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
+ std::shared_ptr<ProcessGroup::Work> allgather_coalesced(
+ std::vector<std::vector<at::Tensor>>& outputTensorLists,
+ std::vector<at::Tensor>& inputTensors,
+ const AllgatherOptions& opts = AllgatherOptions()) override;
+
std::shared_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
diff --git a/torch/lib/c10d/Utils.hpp b/torch/lib/c10d/Utils.hpp
index 374117d..4a81c00 100644
--- a/torch/lib/c10d/Utils.hpp
+++ b/torch/lib/c10d/Utils.hpp
@@ -203,6 +203,20 @@
}
}
+inline void assertSameDevice(
+ std::function<void(const std::string&)> fn,
+ const at::ArrayRef<at::Tensor>& tensors) {
+ if (tensors.size() < 2) {
+ return;
+ }
+ const auto& device = tensors[0].device();
+ for (int i = 1; i < tensors.size(); ++i) {
+ if (tensors[i].device() != device) {
+ fn("tensors should be on the same device");
+ }
+ }
+}
+
inline void assertTypeAndSizesMatch(
std::function<void(const std::string&)> fn,
const at::ArrayRef<at::Tensor>& tensors,