DDP: ~10% NCCL backend perf improvement and mixed-precision support (#5064)

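For context, a minimal usage sketch of the NCCL fast path this patch adds. The process-group setup (`init_method='env://'`, the choice of `device_ids`, the toy model and optimizer) are illustrative assumptions, not part of the patch; the point is that `forward()` arms `need_reduction` and the parameter hooks queue `reduction_fn_nccl` during `backward()`.

```python
# Hypothetical one-process-per-GPU usage; the launcher is assumed to set
# MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE for the env:// init method.
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(backend='nccl', init_method='env://')

model = nn.Linear(10, 10).cuda()                      # toy model (assumption)
ddp_model = DistributedDataParallel(model, device_ids=[0])
optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)

inputs = torch.randn(32, 10).cuda()
targets = torch.randn(32, 10).cuda()

output = ddp_model(inputs)                            # forward() sets need_reduction = True
loss = nn.functional.mse_loss(output, targets)
loss.backward()                                       # parameter hooks queue reduction_fn_nccl,
                                                      # which all-reduces gradients in buckets
optimizer.step()
```
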
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index e40dfb1..853dc90 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -112,9 +112,14 @@
         self.output_device = output_device
         self.broadcast_buffers = broadcast_buffers
 
+        # Flag used by the NCCL backend to make sure we only reduce gradients
+        # once in the execution engine
+        self.need_reduction = False
+
         MB = 1024 * 1024
         # used for intra-node param sync and inter-node sync as well
         self.broadcast_bucket_size = 10 * MB
+        self.nccl_reduce_bucket_size = 256 * MB
 
         # Sync params and buffers
         module_states = list(self.module.state_dict().values())
@@ -135,11 +140,15 @@
         else:
             self._module_copies = [self.module]
 
-        # Currently NCCL backend only supports single reduction thread/bucket
+        # For the NCCL backend, every NCCL call is asynchronous, so we
+        # enqueue all the NCCL reduction calls directly on the default
+        # CUDA stream without spawning separate reduction threads.
+        # This achieves the best performance.
         if dist._backend == dist.dist_backend.NCCL:
-            bucket_bytes_cap = float('inf')
-        else:
-            bucket_bytes_cap = 1 * MB
+            self._register_nccl_grad_hook()
+            return
+
+        bucket_bytes_cap = 1 * MB
 
         # This is a triply-nested list where the "dimensions" are: devices, buckets, bucket_elems
         param_buckets = []
@@ -149,7 +158,6 @@
 
         self.bucket_sizes = []
         self.bucket_map = {}
-        param_types = set()
 
         # We transpose param_buckets, so the loop is over buckets.
         # param_buckets_tuple is a doubly-nested list with "dims": devices, bucket_elems
@@ -161,10 +169,8 @@
                 if idx == 0:
                     # Bucket parameter type tracking
                     bucket_param_type = param_tuple[0].type()
-                    param_types.add(bucket_param_type)
                     # Only gloo and nccl support half-precision
                     if bucket_param_type == torch.cuda.HalfTensor and \
-                            dist._backend != dist.dist_backend.NCCL and \
                             dist._backend != dist.dist_backend.GLOO:
                         raise RuntimeError("DistributedDataParallel currently only "
                                            "supports half precision parameters "
@@ -175,13 +181,6 @@
                     self.bucket_map[p] = bucket_idx
                 self.bucket_sizes[bucket_idx] += 1
 
-        # TODO, adding mixed precision support in NCCL reduction code path
-        # This is because NCCL backend doesn't support multiple reduction
-        # bucket.
-        if len(param_types) > 1 and dist._backend == dist.dist_backend.NCCL:
-            raise RuntimeError("DistributedDataParallel currently doesn't "
-                               "support mixed precision type for NCCL backend")
-
         self.buckets = [[[] for _ in range(len(self.device_ids))] for _ in range(len(self.bucket_sizes))]
         self.bucket_events = [[None] * len(self.device_ids) for _ in range(len(self.bucket_sizes))]
         self.reduced = [False] * len(self.bucket_sizes)
@@ -193,16 +192,22 @@
 
     def __getstate__(self):
         attrs = copy.copy(self.__dict__)
-        del attrs['_grad_accs'], attrs['_reduction_queues'], attrs['_reduction_streams'], \
-            attrs['_reduction_threads'], attrs['_nccl_streams'], attrs['_default_streams']
+        if dist._backend != dist.dist_backend.NCCL:
+            del attrs['_grad_accs'], attrs['_reduction_queues'], \
+                attrs['_reduction_streams'], attrs['_reduction_threads'], \
+                attrs['_nccl_streams'], attrs['_default_streams']
         return attrs
 
     def __setstate__(self, state):
         super(DistributedDataParallel, self).__setstate__(state)
-        self._register_grad_hooks()
-        self._start_reduction_threads()
+        if dist._backend == dist.dist_backend.NCCL:
+            self._register_nccl_grad_hook()
+        else:
+            self._register_grad_hooks()
+            self._start_reduction_threads()
 
     def forward(self, *inputs, **kwargs):
+        self.need_reduction = True
         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
         self._sync_params()
         if len(self.device_ids) == 1:
@@ -274,7 +279,86 @@
                     grad_acc.register_hook(self._make_param_hook(p, device_idx))
                     self._grad_accs.append(grad_acc)
 
+    def _register_nccl_grad_hook(self):
+        """
+        This function registers the all-reduction callback for the
+        NCCL backend. All gradients are all-reduced in a single step.
+        The NCCL reductions are enqueued directly on the default CUDA
+        stream, so no extra synchronization is needed.
+        """
+        # Create a new process group for the NCCL reductions
+        self.nccl_reduction_group_id = dist.new_group()
+
+        def reduction_fn_nccl():
+            # Only run the reduction once per backward pass
+            if not self.need_reduction:
+                return
+
+            self.need_reduction = False
+            all_grads = [[] for _ in range(len(self._module_copies))]
+            all_grads_buckets_iters = []
+
+            # Bucketing all the gradients
+            for dev_idx, module in enumerate(self._module_copies):
+                for param in module.parameters():
+                    if not param.requires_grad or param.grad is None:
+                        continue
+                    if param.grad.requires_grad:
+                        raise RuntimeError("DistributedDataParallel only works "
+                                           "with gradients that don't require "
+                                           "grad")
+                    # Adding the gradients for reduction
+                    all_grads[dev_idx].append(param.grad.data)
+
+                # Now bucket this device's gradients
+                dev_grads_buckets = _take_tensors(all_grads[dev_idx],
+                                                  self.nccl_reduce_bucket_size)
+
+                all_grads_buckets_iters.append(dev_grads_buckets)
+
+            # Now reduce each bucket one after another
+            for grads_batch in zip(*all_grads_buckets_iters):
+                grads_batch_coalesced = []
+                # Coalesce each bucket
+                for dev_idx, dev_grads_batch in enumerate(grads_batch):
+                    dev_id = self.device_ids[dev_idx]
+                    with torch.cuda.device(dev_id):
+                        dev_grads_batch_coalesced = _flatten_dense_tensors(dev_grads_batch)
+                        grads_batch_coalesced.append(dev_grads_batch_coalesced)
+
+                # We will only use device 0's results, but this single op should be
+                # faster than doing the following two operations sequentially:
+                # (1) intra-node reduce to the lead GPU, followed by
+                # (2) inter-node all-reduce across the lead GPUs of all nodes
+                dist.all_reduce_multigpu(grads_batch_coalesced,
+                                         group=self.nccl_reduction_group_id)
+
+                # Now work only on the first device of self.device_ids:
+                # uncoalesce the gradients for each bucket
+                grads_batch_coalesced[0] /= dist.get_world_size()
+                grads_batch_reduced = _unflatten_dense_tensors(grads_batch_coalesced[0], grads_batch[0])
+                for grad, reduced in zip(grads_batch[0], grads_batch_reduced):
+                    grad.copy_(reduced)
+
+            # Clear the gradients to save memory on the replicas
+            for module in self._module_copies[1:]:
+                for param in module.parameters():
+                    if param.requires_grad:
+                        param.grad = None
+                        param.data.set_()
+
+        # Now register the reduction hook on the parameters
+        for p in self.module.parameters():
+            if not p.requires_grad:
+                continue
+
+            def allreduce_hook(*unused):
+                Variable._execution_engine.queue_callback(reduction_fn_nccl)
+
+            p.register_hook(allreduce_hook)
+
     def _make_param_hook(self, param, device_idx):
+
         bucket_idx = self.bucket_map[param]
 
         def distributed_data_parallel_hook(*unused):
@@ -349,10 +433,7 @@
             # We only use the first device for distributed reductions
             dist._register_stream(reduction_streams[0])
 
-            if dist._backend == dist.dist_backend.NCCL:
-                group_id = dist.group.WORLD
-            else:
-                group_id = dist.new_group()
+            group_id = dist.new_group()
 
             self._reduction_threads.append(threading.Thread(
                 target=self._reduction_thread_fn,
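
For reference, the per-bucket reduction in `reduction_fn_nccl` follows a coalesce / all-reduce / average / uncoalesce pattern. The sketch below is not part of the patch; it illustrates that pattern for the simpler one-GPU-per-process case, reusing the same `torch._utils` helpers and substituting `dist.all_reduce` for `dist.all_reduce_multigpu` since only one device is involved.

```python
# Standalone illustration (assumes an already-initialized process group);
# mirrors the bucketed gradient all-reduce in reduction_fn_nccl.
import torch.distributed as dist
from torch._utils import (_take_tensors, _flatten_dense_tensors,
                          _unflatten_dense_tensors)

def allreduce_grads(params, bucket_size=256 * 1024 * 1024):
    grads = [p.grad.data for p in params
             if p.requires_grad and p.grad is not None]
    for bucket in _take_tensors(grads, bucket_size):
        # Coalesce many small gradients into one contiguous tensor so a
        # single collective call covers the whole bucket.
        coalesced = _flatten_dense_tensors(bucket)
        dist.all_reduce(coalesced)
        coalesced /= dist.get_world_size()
        # Copy the averaged values back into the original gradient tensors.
        for grad, reduced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)):
            grad.copy_(reduced)
```

Note that `_take_tensors` groups tensors by type before splitting them into size-limited chunks, so half- and single-precision gradients land in separate buckets; this is what lets the new code path handle mixed-precision models.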