Make grad point to bucket buffer in DDP to save memory usage (#41954)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41954
Make both variable.grad() and the grad in the dist autograd context point to the bucket buffers in DDP to reduce memory usage.
With this change, grads become views of the bucket buffer tensors; to keep this compatible with optimizer.zero_grad(), we
made changes in https://github.com/pytorch/pytorch/pull/41283.
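
As a rough illustration of why detach_() no longer works on these grads (a hedged sketch, not the exact #41283 code; `param` is a hypothetical parameter), zeroing a grad that is a view of a DDP bucket looks like:

    grad = param.grad
    if grad is not None:
        if grad.grad_fn is not None:
            # grads produced by autograd can still be detached in place
            grad.detach_()
        else:
            # bucket-view grads cannot call detach_(); drop requires_grad instead
            grad.requires_grad_(False)
        grad.zero_()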

Also note that we cannot make variable.grad() point to the bucket buffer at construction time, because we want to
keep grad undefined for unused parameters.
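
A hedged sketch of the expected user-visible behavior (`model` and `inputs` are hypothetical):

    ddp = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])
    # right after construction no grads exist yet, so there is nothing to alias
    assert all(p.grad is None for p in ddp.parameters())
    ddp(inputs).sum().backward()
    # parameters that took part in the backward pass now hold grads that are
    # views into the bucket buffers; globally unused parameters keep grad None
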
ghstack-source-id: 110260297

Test Plan:
Unit tests.

For the roberta_base model with ~1GB of parameters, peak memory dropped by ~1GB (8250MB -> 7183MB). Per-iteration latency improved from 0.982s to 0.909s, an ~8% speedup.
https://www.internalfb.com/intern/fblearner/details/211713882?tab=operator_details
https://www.internalfb.com/intern/fblearner/details/211772923?tab=operator_details

For the resnet model with ~97M parameters, peak memory dropped by ~100MB (3089MB -> 2988MB). Per-iteration latency is essentially unchanged (0.122s -> 0.123s).
https://www.internalfb.com/intern/fblearner/details/211713577?tab=operator_details
https://www.internalfb.com/intern/fblearner/details/211712582?tab=operator_details

The accuracy benchmark results are as expected as well.
https://www.internalfb.com/intern/fblearner/details/213237067?tab=Outputs

Reviewed By: mrshenli

Differential Revision: D22707857

fbshipit-source-id: b5e767cfb34ccb3d067db2735482a86d59aea7a4
diff --git a/test/distributed/test_distributed.py b/test/distributed/test_distributed.py
index c353c61..e03335c 100644
--- a/test/distributed/test_distributed.py
+++ b/test/distributed/test_distributed.py
@@ -2119,7 +2119,7 @@
                 # Clear gradients manually
                 grad = net.module.weight.grad
                 if grad is not None:
-                    grad.detach_()
+                    grad.requires_grad_(False)
                     grad.zero_()
                 # Forward + BW
                 batch = torch.tensor([rank]).float().cuda(rank)
diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp
index c72c67e..0a714b3 100644
--- a/torch/csrc/autograd/VariableTypeManual.cpp
+++ b/torch/csrc/autograd/VariableTypeManual.cpp
@@ -269,7 +269,11 @@
                    "of detach_(). Alternatively, create this view with an "
                    "`unsafe_` version of the function that produced it.");
     } else {
-      AT_ERROR("Can't detach views in-place. Use detach() instead");
+      AT_ERROR("If you are using DistributedDataParallel (DDP) for training, "
+               "gradients are views of DDP buckets, and hence detach_() cannot "
+               "be called on these gradients. To fix this error, please refer "
+               "to the Optimizer.zero_grad() function "
+               "in torch/optim/optimizer.py as the solution.");
     }
   }
   // I think the choice here is conservative.  In principle, doing
diff --git a/torch/csrc/autograd/functions/accumulate_grad.h b/torch/csrc/autograd/functions/accumulate_grad.h
index e1a02dc..dafd07f 100644
--- a/torch/csrc/autograd/functions/accumulate_grad.h
+++ b/torch/csrc/autograd/functions/accumulate_grad.h
@@ -161,6 +161,11 @@
         // valid operation which adds `new_grad` to `variable_grad` in
         // place. `variable_grad` is thus still referring to the same tensor
         // after the operation.
+        // Also, the DistributedDataParallel (DDP) package relies on grad
+        // being mutated in place to keep peak memory usage low. DDP will
+        // still work correctly if grad is mutated out of place here, but
+        // it will then maintain one extra copy of the grad tensors in its
+        // buffers and thus increase peak memory usage.
         variable_grad += new_grad;
         CHECK_RESULT(variable_grad, variable);
         // ^ We could enforce the contract more aggressively here by writing:
diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp
index cc35d08..5c7d2a0 100644
--- a/torch/csrc/distributed/c10d/init.cpp
+++ b/torch/csrc/distributed/c10d/init.cpp
@@ -168,8 +168,8 @@
           py::arg("find_unused_parameters") = false,
           py::call_guard<py::gil_scoped_release>())
       .def(
-          "initialize_buckets",
-          &::c10d::Reducer::initialize_buckets,
+          "prepare_forward",
+          &::c10d::Reducer::prepare_forward,
           py::call_guard<py::gil_scoped_release>())
       .def(
           "prepare_for_backward",
diff --git a/torch/csrc/distributed/c10d/reducer.cpp b/torch/csrc/distributed/c10d/reducer.cpp
index 4c509ab..87dbe5a 100644
--- a/torch/csrc/distributed/c10d/reducer.cpp
+++ b/torch/csrc/distributed/c10d/reducer.cpp
@@ -64,7 +64,9 @@
 
   // Initialize variable bucketing.
   // This can be reinitialized later after capturing runtime information.
+  std::unique_lock<std::mutex> lock(this->mutex_);
   initialize_buckets(std::move(bucket_indices));
+  lock.unlock();
 
   // All variables are expected to have their `grad_fn` set to the gradient
   // accumulation function (since they are leafs in the autograd graph).
@@ -315,56 +317,66 @@
   const auto length = replica.lengths[bucket_index.intra_bucket_index];
   auto& bucket_view = replica.bucket_views[bucket_index.intra_bucket_index];
 
-  // Copy contents of gradient tensor to bucket tensor.
-  // If the gradient is not set, we assume it wasn't computed
-  // as part of the current backwards pass, and zero the part
-  // of the bucket it would otherwise hold.
   runGradCallbackForVariable(variable, [&](auto& grad) {
     if (grad.defined()) {
-      // Ensure that the gradient type matches the bucket type.
-      TORCH_CHECK(
-          grad.options().type_equal(bucket_view.options()),
-          "Expected ",
-          bucket_view.toString(),
-          ", got ",
-          grad.toString());
-      // Assert that the grad tensor and the bucket don't share storage.
-      // If they did, we could avoid the copy altogether.
-      // The reason for not doing this is that existing code calls
-      // `detach_` from `zero_grad`, which is incompatible with views.
-      TORCH_INTERNAL_ASSERT(!grad.is_alias_of(bucket_view));
-      TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
-      TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
-      // AccumulateGrad doesn't HAVE to obey the grad layout contract.
-      // The penalty for disobedience is reduced performance, not numerical
-      // death. Warnings here help diagnose poor DDP performance.
-      if (grad.strides() != bucket_view.strides()) {
-        TORCH_WARN_ONCE(
-            "Grad strides do not match bucket view strides. "
-            "This may indicate grad was not created according to the "
-            "gradient layout contract, or that the param's strides "
-            "changed since DDP was constructed.  This is not an error, "
-            "but may impair performance.\n"
-            "grad.sizes() = ",
-            grad.sizes(),
-            ", strides() = ",
-            grad.strides(),
-            "\n",
-            "bucket_view.sizes() = ",
-            bucket_view.sizes(),
-            ", strides() = ",
-            bucket_view.strides());
-      }
-      // See Note [DDP Communication Hook]
-      if (comm_hook_ == nullptr) {
-        // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
-        auto wrapped =
-            c10::scalar_to_tensor(double(1.) / process_group_->getSize());
-        wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
-        // Divides while copying into the bucket view.
-        at::native::mul_out(bucket_view, grad, wrapped);
+      // Copy grad into the bucket view buffer if grad and bucket_view point
+      // to different storages, and then make grad point to bucket_view to
+      // save memory and avoid copies in subsequent iterations.
+      // In most cases the copy is needed only in the first iteration; there
+      // will be no copies in subsequent iterations.
+      // In the rare case that users explicitly set grad to None after every
+      // iteration, grad needs to be copied into bucket_view in every
+      // iteration.
+      if (!grad.is_alias_of(bucket_view)) {
+        // Ensure that the gradient type matches the bucket type.
+        TORCH_CHECK(
+            grad.options().type_equal(bucket_view.options()),
+            "Expected ",
+            bucket_view.toString(),
+            ", got ",
+            grad.toString());
+        TORCH_INTERNAL_ASSERT(grad.device() == bucket_view.device());
+        TORCH_INTERNAL_ASSERT(grad.numel() == bucket_view.numel());
+        // AccumulateGrad doesn't HAVE to obey the grad layout contract.
+        // The penalty for disobedience is reduced performance, not numerical
+        // death. Warnings here help diagnose poor DDP performance.
+        if (grad.strides() != bucket_view.strides()) {
+          TORCH_WARN_ONCE(
+              "Grad strides do not match bucket view strides. "
+              "This may indicate grad was not created according to the "
+              "gradient layout contract, or that the param's strides "
+              "changed since DDP was constructed.  This is not an error, "
+              "but may impair performance.\n"
+              "grad.sizes() = ",
+              grad.sizes(),
+              ", strides() = ",
+              grad.strides(),
+              "\n",
+              "bucket_view.sizes() = ",
+              bucket_view.sizes(),
+              ", strides() = ",
+              bucket_view.strides());
+        }
+        // See Note [DDP Communication Hook]
+        if (comm_hook_ == nullptr) {
+          // imitates wrapped_scalar_tensor in ATen/native/BinaryOps.cpp
+          auto wrapped =
+              c10::scalar_to_tensor(double(1.) / process_group_->getSize());
+          wrapped.unsafeGetTensorImpl()->set_wrapped_number(true);
+          // Divides while copying into the bucket view.
+          at::native::mul_out(bucket_view, grad, wrapped);
+        } else {
+          bucket_view.copy_(grad);
+        }
+        // Let grad point to bucket_view buffer.
+        grad = bucket_view;
+        // The grad is modified and needs to be written back.
+        return true;
       } else {
-        bucket_view.copy_(grad);
+        // If grad and bucket view point to the same storage, no need to copy
+        if (comm_hook_ == nullptr) {
+          bucket_view.div_(process_group_->getSize());
+        }
       }
     } else {
       bucket_view.zero_();
@@ -552,20 +564,10 @@
     const c10::Stream currentStream =
         guard.getStream(replica.contents.device());
     torch::autograd::Engine::get_default_engine().queue_callback([=] {
-      std::unique_lock<std::mutex> lock(this->mutex_);
+      std::lock_guard<std::mutex> lock(this->mutex_);
       // Run callback with the current stream
       c10::OptionalStreamGuard currentStreamGuard{currentStream};
       this->finalize_backward();
-      // Rebuild bucket if this is the first time to rebuild
-      if (!rebuilt_params_.empty()) {
-        auto rebuilt_bucket_indices = rebuildBuckets();
-        // Unlock before initialize_buckets() as initialize_buckets() requires a
-        // lock, it could result in self deadlock without unlocking here.
-        lock.unlock();
-        initialize_buckets(std::move(rebuilt_bucket_indices));
-      } else {
-        lock.unlock();
-      }
     });
   }
 }
@@ -613,7 +615,16 @@
 
 void Reducer::initialize_buckets(
     std::vector<std::vector<size_t>> bucket_indices) {
-  std::lock_guard<std::mutex> lock(mutex_);
+  // If initialize_buckets is called inside the DDP constructor, it does
+  // not matter whether the rpc context ptr is nullptr, as grad will not
+  // be mutated.
+  // If initialize_buckets is called during the training loop, e.g. inside
+  // rebuild_buckets(), grad could be mutated and pointed to bucket_view,
+  // so we need to check whether the rpc context ptr is nullptr:
+  // if it is nullptr, mutate variable.grad(); otherwise, mutate the grad
+  // in the rpc context.
+  using torch::distributed::autograd::ThreadLocalDistAutogradContext;
+  this->rpc_context_.set(ThreadLocalDistAutogradContext::getContextPtr());
 
   // This shouldn't be called if we're expecting autograd hooks to fire.
   TORCH_CHECK(
@@ -697,7 +708,6 @@
 
         // Allocate bucket contents tensor.
         replica.contents = at::empty({static_cast<long>(offset)}, options);
-
         // Note:  "Gradient Layout Contract"
         //
         // Here, create views into the contents tensor for each variable's grad.
@@ -735,7 +745,7 @@
         // metadata.  Checking just once won't catch if someone messes with
         // param layouts over time, but not messing with params after DDP
         // construction is already a documented constraint.
-        initialize_bucketviews(replica, replica.contents);
+        initialize_bucket_views(replica, replica.contents, true);
       }
 
       // Add bucket replica to enclosing bucket.
@@ -761,29 +771,61 @@
 }
 
 // (see Note:  "Gradient Layout Contract" in initialize_buckets).
-void Reducer::initialize_bucketviews(
+void Reducer::initialize_bucket_views(
     Reducer::BucketReplica& replica,
-    at::Tensor& contents) {
+    at::Tensor& contents,
+    bool copy_to_bucket_view) {
   for (size_t i = 0; i < replica.variables.size(); i++) {
-    const auto& v = replica.variables[i];
+    auto& v = replica.variables[i];
     const auto offset = replica.offsets[i];
     const auto length = replica.lengths[i];
+    at::Tensor bucket_view;
     if (v.is_non_overlapping_and_dense()) {
       // If the param's memory is dense, match its layout, anticipating
       // the autograd engine (AccumulateGrad) will also create gradients
       // matching its layout.
-      replica.bucket_views.push_back(
-          contents.as_strided(v.sizes(), v.strides(), offset));
+      bucket_view = contents.as_strided(v.sizes(), v.strides(), offset);
     } else {
       // Fall back to a C-style contiguous view, again anticipating
       // AccumulateGrad will do the same when stashing grads for non-dense
       // params.
-      replica.bucket_views.push_back(
-          contents.narrow(0, offset, length).view(v.sizes()));
+      bucket_view = contents.narrow(0, offset, length).view(v.sizes());
     }
+    replica.bucket_views.push_back(bucket_view);
+    // There are three cases to handle:
+    // 1. initialize_bucket_views is called from a communication hook;
+    // bucket_view already holds the updated results in the new tensor, so
+    // just make grad point to bucket_view. copy_to_bucket_view is false in
+    // this case.
+    // 2. initialize_bucket_views is called from initialize_buckets during
+    // rebuild_buckets; if grad was already defined/computed in a previous
+    // iteration, the old grad needs to be copied into the new bucket_view,
+    // and grad then points to the new bucket_view. copy_to_bucket_view is
+    // true in this case.
+    // 3. initialize_bucket_views is called from initialize_buckets during
+    // construction. copy_to_bucket_view is true in this case, but grads are
+    // usually not defined at construction time. When grad is not defined, do
+    // not make it point to bucket_view, because grads must stay undefined
+    // for globally unused parameters.
+    runGradCallbackForVariable(v, [&](auto& grad) {
+      if (grad.defined() && !grad.is_alias_of(bucket_view)) {
+        if (copy_to_bucket_view) {
+          bucket_view.copy_(grad);
+        }
+        grad = bucket_view;
+        // The grad is modified and needs to be written back.
+        return true;
+      }
+      // The grad is not modified and does not need to be written back.
+      return false;
+    });
   }
 }
 
+void Reducer::prepare_forward() {
+  std::lock_guard<std::mutex> lock(mutex_);
+  rebuild_buckets();
+}
+
 // Traverse the autograd graph starting at the specified output.
 // All parameters for which we have a pointer to their gradient accumulation
 // functions, but don't show up in the autograd graph will be marked ready for
@@ -931,13 +973,14 @@
       runGradCallbackForVariable(variable, [&](auto& grad) {
         // If a parameter is globally unused, we keep its grad untouched.
         if (!global_unused) {
+          // If grad is globally used but locally unused, make grad point to
+          // bucket_view.
           if (!grad.defined()) {
-            // Creates grad according to the "Gradient Layout Contract"
-            // (see torch/csrc/grad/AccumulateGrad.h)
-            grad = torch::autograd::utils::clone_obey_contract(
-                bucket_view, variable);
+            grad = bucket_view;
           } else {
-            grad.copy_(bucket_view);
+            TORCH_INTERNAL_ASSERT(
+                grad.is_alias_of(bucket_view),
+                "Grad should have been pointed to bucket_view if grad is defined");
           }
           // The grad is modified and needs to be written back.
           return true;
@@ -987,7 +1030,7 @@
           // Reinitialize bucket_views with the future_result by following
           // the same logic in `inititalize_buckets`.
           bucket.replicas[i].bucket_views.clear();
-          initialize_bucketviews(bucket.replicas[i], future_result[i]);
+          initialize_bucket_views(bucket.replicas[i], future_result[i], false);
         }
       }
     }
@@ -1115,7 +1158,11 @@
   }
 }
 
-std::vector<std::vector<size_t>> Reducer::rebuildBuckets() {
+void Reducer::rebuild_buckets() {
+  if (rebuilt_params_.empty()) {
+    return;
+  }
+
   TORCH_INTERNAL_ASSERT(
       rebuilt_params_.size() == rebuilt_param_indices_.size(),
       "rebuilt parameter tensors size is not same as rebuilt parameter indices size.");
@@ -1141,7 +1188,7 @@
   rebuilt_params_.clear();
   rebuilt_param_indices_.clear();
 
-  return rebuilt_bucket_indices;
+  initialize_buckets(std::move(rebuilt_bucket_indices));
 }
 
 // See Note [DDP Communication Hook]
diff --git a/torch/csrc/distributed/c10d/reducer.h b/torch/csrc/distributed/c10d/reducer.h
index 8a59bfc..d1b1513 100644
--- a/torch/csrc/distributed/c10d/reducer.h
+++ b/torch/csrc/distributed/c10d/reducer.h
@@ -34,11 +34,13 @@
 
   ~Reducer() noexcept(false);
 
-  // To (re-)initialize bucket assignment, pass a list of buckets, each
-  // of which is specified by a list of indices in the variables list.
-  // This function performs validation that the variables within a bucket
-  // all live on the same device and have the same dimensionality.
-  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
+  // This function is called before the forward computation,
+  // e.g. to run rebuild_buckets.
+  // rebuild_buckets may allocate new buckets before deallocating
+  // the old ones. To keep peak memory usage low, it should be
+  // called before peak memory usage increases during the forward
+  // computation.
+  void prepare_forward();
 
   // This function is called when the forward function has produced an output,
   // and the user wishes to reduce gradients in the backwards pass.
@@ -123,9 +125,16 @@
 
   void finalize_backward();
 
+  // To (re-)initialize bucket assignment, pass a list of buckets, each
+  // of which is specified by a list of indices in the variables list.
+  // This function performs validation that the variables within a bucket
+  // all live on the same device and have the same dimensionality.
+  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
+
   // Broadcast rebuilt buckets from rank 0 to other ranks before initializing
   // the buckets
   void sync_bucket_indices(std::vector<std::vector<size_t>>& bucket_indices);
+
   // Rebuild buckets based on rebuilt_params_ and rebuilt_param_indices_
   // TODO this function makes broadcast communication call and
   // could be overlapped with next forward() call, thus
@@ -135,7 +144,7 @@
   // and parameter indices order may change more frequently.
   // For find_unused_parameters = false case, buckets are only rebuilt once,
   // the performance cost is negligible.
-  std::vector<std::vector<size_t>> rebuildBuckets();
+  void rebuild_buckets();
 
   using GradCallback =
       torch::distributed::autograd::DistAutogradContext::GradCallback;
@@ -189,11 +198,19 @@
   // This function is called inside `initialize_buckets` and
   // `finalize_backward`. The function call in `initialize_bucket` creates views
   // into the contents tensor for each variable's grad. Views serve as entry
-  // points to copy_ each grad's data in/out of the flat contents tensor. The
-  // function call in `finalize_backward` happens only if DDP communication hook
-  // was registered to recrate views with the result of `future_work`. Before
-  // `finalize_backward` call, views must be cleared.
-  void initialize_bucketviews(BucketReplica& replica, at::Tensor& contents);
+  // points that refer to each grad's data in the flat contents tensor. When it
+  // is called inside `initialize_buckets`, copy_to_bucket_view is true, meaning
+  // grad needs to be copied into bucket_view.
+  // The function call in `finalize_backward` happens only if a DDP communication
+  // hook was registered, to recreate views with the result of `future_work`.
+  // Before the `finalize_backward` call, views must be cleared. In this case,
+  // copy_to_bucket_view is false: grad does not need to be copied into
+  // bucket_view, as grad has already been mutated into bucket_view, so just make
+  // grad point to bucket_view here.
+  void initialize_bucket_views(
+      BucketReplica& replica,
+      at::Tensor& contents,
+      bool copy_to_bucket_view);
 
   // A bucket holds N bucket replicas (1 per model replica).
   //
diff --git a/torch/nn/parallel/distributed.py b/torch/nn/parallel/distributed.py
index b11d327..f8200fa 100644
--- a/torch/nn/parallel/distributed.py
+++ b/torch/nn/parallel/distributed.py
@@ -571,6 +571,8 @@
         if self.require_forward_param_sync:
             self._sync_params()
 
+        self.reducer.prepare_forward()
+
         if self.device_ids:
             inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
             if len(self.device_ids) == 1: