Put sparse all reduce results into input tensors (#32226)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32226
Right now, if users call torch.distributed.all_reduce() on dense tensors, the outputs are written into the input tensors. But if users call torch.distributed.all_reduce() on sparse tensors, the outputs are neither returned explicitly to the user nor written into the input tensors.
To give torch.distributed.all_reduce() the same behavior on both dense and sparse tensors, this diff makes all_reduce() on sparse tensors write the output into the input tensors as well. This is achieved by simply calling input_sparse.copy_(output_sparse); see PR https://github.com/pytorch/pytorch/pull/9005, which implemented copy_ for sparse tensors.
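To illustrate, a minimal sketch of the intended user-facing behavior after this change (assuming a gloo process group has already been initialized on every rank; the tensor contents are arbitrary):

    import torch
    import torch.distributed as dist

    # Assumes dist.init_process_group("gloo", ...) has run on every rank;
    # sparse allreduce is only supported by the gloo backend.
    indices = torch.tensor([[0], [0]])
    values = torch.ones(1)
    t = torch.sparse_coo_tensor(indices, values, (2, 2))

    dist.all_reduce(t, op=dist.ReduceOp.SUM)

    # After this change, t itself holds the reduced result (the sum
    # across ranks), matching the in-place behavior on dense tensors.
    print(t.to_dense())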
Closes #31413
ghstack-source-id: 96984228
Test Plan: unit test
Differential Revision: D19192952
fbshipit-source-id: 2dd31dc057f20cc42b44b9e55df864afa2918c33
diff --git a/test/common_distributed.py b/test/common_distributed.py
index bcf0ad6..acb3ef5 100644
--- a/test/common_distributed.py
+++ b/test/common_distributed.py
@@ -14,6 +14,7 @@
import torch
import torch.distributed as c10d
+from functools import partial, reduce
from common_utils import TestCase, TEST_WITH_ROCM
TestSkip = namedtuple('TestSkip', 'exit_code, message')
@@ -110,6 +111,49 @@
return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)
+def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
+ """
+ Generate a number of basic test cases for sparse reduction.
+ These cover tensors with a varying number of sparse dimensions and a varying
+ number of dense dimensions. The only reduction operation we support is sum.
+ """
+ def generate(rank, world_size, sparse_dims=1, dense_dims=0):
+ # First sparse dimension is [0..rank].
+ # Subsequent dimensions are always 0, so we know there is
+ # a non-empty intersection between any two sparse tensors.
+ indices = [range(rank + 1)]
+ shape = [world_size] + [2 for _ in range(dense_dims)]
+ for _ in range(sparse_dims - 1):
+ indices.append([0] * (rank + 1))
+ shape.append(world_size)
+ values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+ return torch.sparse_coo_tensor(indices, values, shape)
+
+ def compute_sum(fn, world_size):
+ return reduce(lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)])
+
+ return [
+ (
+ [
+ fn(num_inputs * rank + i, num_inputs * world_size)
+ for i in range(num_inputs)
+ ],
+ [
+ compute_sum(fn, num_inputs * world_size)
+ for i in range(num_inputs)
+ ],
+ )
+ for fn in [
+ partial(generate, sparse_dims=1),
+ partial(generate, sparse_dims=2),
+ partial(generate, sparse_dims=3),
+ partial(generate, dense_dims=1),
+ partial(generate, dense_dims=2),
+ partial(generate, dense_dims=3),
+ ]
+ ]
+
+
class MultiProcessTestCase(TestCase):
MAIN_PROCESS_RANK = -1
# This exit code is used to indicate that the test code had an error and
diff --git a/test/test_c10d.py b/test/test_c10d.py
index dc43d10..73f1892 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -14,7 +14,7 @@
from sys import platform
from itertools import groupby
-from functools import partial, reduce
+from functools import reduce
import operator
import torch
@@ -28,7 +28,8 @@
from common_distributed import MultiProcessTestCase, \
requires_gloo, requires_nccl, requires_nccl_version, \
- skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout, skip_if_rocm
+ skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout, skip_if_rocm, \
+ simple_sparse_reduce_tests
from common_utils import TestCase, load_tests, run_tests, retry_on_address_already_in_use_error, TEST_WITH_TSAN
# load_tests from common_utils is used to automatically filter tests for
@@ -185,49 +186,6 @@
]
-def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
- """
- Generate a number of basic test cases for sparse reduction.
- These cover tensors with a varying number of sparse dimensions and a varying
- number of dense dimensions. The only reduction operation we support is sum.
- """
- def generate(rank, world_size, sparse_dims=1, dense_dims=0):
- # First sparse dimension is [0..rank].
- # Subsequent dimensions are always 0, so we know there is
- # a non-empty intersection between any two sparse tensors.
- indices = [range(rank + 1)]
- shape = [world_size] + [2 for _ in range(dense_dims)]
- for _ in range(sparse_dims - 1):
- indices.append([0] * (rank + 1))
- shape.append(world_size)
- values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
- return torch.sparse_coo_tensor(indices, values, shape)
-
- def compute_sum(fn, world_size):
- return reduce(lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)])
-
- return [
- (
- [
- fn(num_inputs * rank + i, num_inputs * world_size)
- for i in range(num_inputs)
- ],
- [
- compute_sum(fn, num_inputs * world_size)
- for i in range(num_inputs)
- ],
- )
- for fn in [
- partial(generate, sparse_dims=1),
- partial(generate, sparse_dims=2),
- partial(generate, sparse_dims=3),
- partial(generate, dense_dims=1),
- partial(generate, dense_dims=2),
- partial(generate, dense_dims=3),
- ]
- ]
-
-
class StoreTestBase(object):
def _create_store(self, i):
raise RuntimeError("not implemented")
@@ -944,8 +902,10 @@
self.world_size,
num_inputs=num_inputs_per_rank)
for (inputs, outputs) in tests:
- work = pg.allreduce([fn(input) for input in inputs])
+ tensors = [fn(input) for input in inputs]
+ work = pg.allreduce(tensors)
work.wait()
+ self.assertEqual(tensors, outputs)
self.assertEqual(work.result(), outputs)
def test_sparse_allreduce_basics(self):
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 4aa2c0d..3766a14 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -20,6 +20,7 @@
from common_utils import TestCase, run_tests, skipIfRocm
from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
+from common_distributed import simple_sparse_reduce_tests, skip_if_rocm
try:
import torchvision
@@ -1013,6 +1014,30 @@
group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
)
+ # SPARSE ALL REDUCE
+ def _test_sparse_all_reduce_sum(self, fn):
+ group, group_id, rank = self._init_global_test()
+
+ tests = simple_sparse_reduce_tests(
+ rank,
+ dist.get_world_size(),
+ num_inputs=1)
+ for (inputs, outputs) in tests:
+ tensors = [fn(input) for input in inputs]
+ dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
+ self.assertEqual(tensors[0], outputs[0])
+
+    @unittest.skipIf(BACKEND != "gloo", "Only Gloo backend supports sparse all reduce")
+ def test_sparse_all_reduce_sum(self):
+ self._test_sparse_all_reduce_sum(lambda t: t)
+
+    @unittest.skipIf(BACKEND != "gloo", "Only Gloo backend supports sparse all reduce")
+ @skip_if_no_cuda_distributed
+ @skip_if_no_gpu
+ @skip_if_rocm
+ def test_sparse_all_reduce_sum_cuda(self):
+ self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
+
# ALL REDUCE - COALESCED
@staticmethod
def _all_reduce_coalesced_sum_test_cases(group_size):
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 4adadf8..451f47a 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -976,6 +976,7 @@
// Copy back to input tensors.
outputs.reserve(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
+ inputs[i].copy_(output);
if (output.is_sparse()) {
outputs.push_back(output.clone());
} else {
@@ -1210,6 +1211,12 @@
guard.set_index(inputs[i].device().index());
events[i].block(at::cuda::getCurrentCUDAStream());
}
+
+    // Copy outputs back to inputs after synchronization so that users can
+    // access the allreduce results from the input tensors.
+ for (size_t i = 0; i < inputs.size(); i++) {
+ inputs[i].copy_(outputs[i]);
+ }
}
std::vector<at::Tensor> tmp;