Put sparse all reduce results to input tensors (#32226)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32226

Right now, if users call torch.distributed.all_reduce() on dense tensors, the outputs are put in the input tensors. But if users call torch.distributed.all_reduce() on sparse tensors, the outputs are neither returned explicitly to users nor put in the input tensors.

To make the torch.distributed.all_reduce() API behave the same on both dense and sparse tensors, this diff makes torch.distributed.all_reduce() on sparse tensors put the output in the input tensors as well. This is achieved by simply calling input_sparse.copy_(output_sparse); see PR https://github.com/pytorch/pytorch/pull/9005, which implemented copy_ for sparse tensors.

close #31413
ghstack-source-id: 96984228

Test Plan: unit test

Differential Revision: D19192952

fbshipit-source-id: 2dd31dc057f20cc42b44b9e55df864afa2918c33
diff --git a/test/common_distributed.py b/test/common_distributed.py
index bcf0ad6..acb3ef5 100644
--- a/test/common_distributed.py
+++ b/test/common_distributed.py
@@ -14,6 +14,7 @@
 import torch
 import torch.distributed as c10d
 
+from functools import partial, reduce
 from common_utils import TestCase, TEST_WITH_ROCM
 
 TestSkip = namedtuple('TestSkip', 'exit_code, message')
@@ -110,6 +111,49 @@
     return TIMEOUT_OVERRIDE.get(test_id.split('.')[-1], TIMEOUT_DEFAULT)
 
 
+def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
+    """
+    Generate a number of basic test cases for sparse reduction.
+    These cover tensors with a varying number of sparse dimensions and a varying
+    number of dense dimensions. The only reduction operation we support is sum.
+    """
+    def generate(rank, world_size, sparse_dims=1, dense_dims=0):
+        # First sparse dimension is [0..rank].
+        # Subsequent dimensions are always 0, so we know there is
+        # a non-empty intersection between any two sparse tensors.
+        indices = [range(rank + 1)]
+        shape = [world_size] + [2 for _ in range(dense_dims)]
+        for _ in range(sparse_dims - 1):
+            indices.append([0] * (rank + 1))
+            shape.append(world_size)
+        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
+        return torch.sparse_coo_tensor(indices, values, shape)
+
+    def compute_sum(fn, world_size):
+        return reduce(lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)])
+
+    return [
+        (
+            [
+                fn(num_inputs * rank + i, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+            [
+                compute_sum(fn, num_inputs * world_size)
+                for i in range(num_inputs)
+            ],
+        )
+        for fn in [
+            partial(generate, sparse_dims=1),
+            partial(generate, sparse_dims=2),
+            partial(generate, sparse_dims=3),
+            partial(generate, dense_dims=1),
+            partial(generate, dense_dims=2),
+            partial(generate, dense_dims=3),
+        ]
+    ]
+
+
 class MultiProcessTestCase(TestCase):
     MAIN_PROCESS_RANK = -1
     # This exit code is used to indicate that the test code had an error and
diff --git a/test/test_c10d.py b/test/test_c10d.py
index dc43d10..73f1892 100644
--- a/test/test_c10d.py
+++ b/test/test_c10d.py
@@ -14,7 +14,7 @@
 from sys import platform
 
 from itertools import groupby
-from functools import partial, reduce
+from functools import reduce
 import operator
 
 import torch
@@ -28,7 +28,8 @@
 
 from common_distributed import MultiProcessTestCase, \
     requires_gloo, requires_nccl, requires_nccl_version, \
-    skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout, skip_if_rocm
+    skip_if_not_multigpu, skip_if_lt_x_gpu, skip_for_known_issues, get_timeout, skip_if_rocm, \
+    simple_sparse_reduce_tests
 from common_utils import TestCase, load_tests, run_tests, retry_on_address_already_in_use_error, TEST_WITH_TSAN
 
 # load_tests from common_utils is used to automatically filter tests for
@@ -185,49 +186,6 @@
     ]
 
 
-def simple_sparse_reduce_tests(rank, world_size, num_inputs=1):
-    """
-    Generate a number of basic test cases for sparse reduction.
-    These cover tensors with a varying number of sparse dimensions and a varying
-    number of dense dimensions. The only reduction operation we support is sum.
-    """
-    def generate(rank, world_size, sparse_dims=1, dense_dims=0):
-        # First sparse dimension is [0..rank].
-        # Subsequent dimensions are always 0, so we know there is
-        # a non-empty intersection between any two sparse tensors.
-        indices = [range(rank + 1)]
-        shape = [world_size] + [2 for _ in range(dense_dims)]
-        for _ in range(sparse_dims - 1):
-            indices.append([0] * (rank + 1))
-            shape.append(world_size)
-        values = torch.ones([rank + 1] + [2 for _ in range(dense_dims)])
-        return torch.sparse_coo_tensor(indices, values, shape)
-
-    def compute_sum(fn, world_size):
-        return reduce(lambda a, b: a + b, [fn(rank, world_size) for rank in range(world_size)])
-
-    return [
-        (
-            [
-                fn(num_inputs * rank + i, num_inputs * world_size)
-                for i in range(num_inputs)
-            ],
-            [
-                compute_sum(fn, num_inputs * world_size)
-                for i in range(num_inputs)
-            ],
-        )
-        for fn in [
-            partial(generate, sparse_dims=1),
-            partial(generate, sparse_dims=2),
-            partial(generate, sparse_dims=3),
-            partial(generate, dense_dims=1),
-            partial(generate, dense_dims=2),
-            partial(generate, dense_dims=3),
-        ]
-    ]
-
-
 class StoreTestBase(object):
     def _create_store(self, i):
         raise RuntimeError("not implemented")
@@ -944,8 +902,10 @@
                 self.world_size,
                 num_inputs=num_inputs_per_rank)
             for (inputs, outputs) in tests:
-                work = pg.allreduce([fn(input) for input in inputs])
+                tensors = [fn(input) for input in inputs]
+                work = pg.allreduce(tensors)
                 work.wait()
+                self.assertEqual(tensors, outputs)
                 self.assertEqual(work.result(), outputs)
 
     def test_sparse_allreduce_basics(self):
diff --git a/test/test_distributed.py b/test/test_distributed.py
index 4aa2c0d..3766a14 100644
--- a/test/test_distributed.py
+++ b/test/test_distributed.py
@@ -20,6 +20,7 @@
 from common_utils import TestCase, run_tests, skipIfRocm
 from torch._utils_internal import TEST_MASTER_ADDR as MASTER_ADDR
 from torch._utils_internal import TEST_MASTER_PORT as MASTER_PORT
+from common_distributed import simple_sparse_reduce_tests, skip_if_rocm
 
 try:
     import torchvision
@@ -1013,6 +1014,30 @@
             group, group_id, rank, dist.ReduceOp.MAX, -1, 10, 10
         )
 
+    # SPARSE ALL REDUCE
+    def _test_sparse_all_reduce_sum(self, fn):
+        group, group_id, rank = self._init_global_test()
+
+        tests = simple_sparse_reduce_tests(
+            rank,
+            dist.get_world_size(),
+            num_inputs=1)
+        for (inputs, outputs) in tests:
+            tensors = [fn(input) for input in inputs]
+            dist.all_reduce(tensors[0], dist.ReduceOp.SUM, group_id)
+            self.assertEqual(tensors[0], outputs[0])
+
+    @unittest.skipIf(BACKEND != "gloo", "Only Gloo backend support sparse all reduce")
+    def test_sparse_all_reduce_sum(self):
+        self._test_sparse_all_reduce_sum(lambda t: t)
+
+    @unittest.skipIf(BACKEND != "gloo", "Only Gloo backend support sparse all reduce")
+    @skip_if_no_cuda_distributed
+    @skip_if_no_gpu
+    @skip_if_rocm
+    def test_sparse_all_reduce_sum_cuda(self):
+        self._test_sparse_all_reduce_sum(lambda t: t.clone().cuda())
+
     # ALL REDUCE - COALESCED
     @staticmethod
     def _all_reduce_coalesced_sum_test_cases(group_size):
diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp
index 4adadf8..451f47a 100644
--- a/torch/lib/c10d/ProcessGroupGloo.cpp
+++ b/torch/lib/c10d/ProcessGroupGloo.cpp
@@ -976,6 +976,7 @@
     // Copy back to input tensors.
     outputs.reserve(inputs.size());
     for (size_t i = 0; i < inputs.size(); i++) {
+      inputs[i].copy_(output);
       if (output.is_sparse()) {
         outputs.push_back(output.clone());
       } else {
@@ -1210,6 +1211,12 @@
       guard.set_index(inputs[i].device().index());
       events[i].block(at::cuda::getCurrentCUDAStream());
     }
+
+    // Copy outputs back to inputs after synchronization, so that users can
+    // access all reduce results from input tensors
+    for (size_t i = 0; i < inputs.size(); i++) {
+      inputs[i].copy_(outputs[i]);
+    }
   }
 
   std::vector<at::Tensor> tmp;