[CUDA graphs][BC-breaking] Removes post-backward syncs on default stream (#60421)

Summary:
Before https://github.com/pytorch/pytorch/pull/57833, calls to backward() or grad() synced only the calling thread's default stream with autograd leaf streams at the end of backward. This made the following weird pattern safe:
```python
with torch.cuda.stream(s):
    # imagine forward used many streams, so backward leaf nodes may run on many streams
    loss.backward()
# no sync
use grads
```

but a more benign-looking pattern was unsafe:
```python
with torch.cuda.stream(s):
    # imagine forward used many streams, so backward leaf nodes may run on many streams
    loss.backward()
    # backward() syncs the default stream with all the leaf streams, but does not sync s with anything,
    # so counterintuitively (even though we're in the same stream context as backward()!)
    # it is NOT SAFE to use grads here, and there's no easy way to make it safe,
    # unless you manually sync on all the streams you used in forward,
    # or move "use grads" back to default stream outside the context.
    use grads
```
mruberry, ngimel, and I decided backward() should have the [same user-facing stream semantics as any other CUDA op](https://pytorch.org/docs/master/notes/cuda.html#stream-semantics-of-backward-passes).** In other words, the weird pattern should be unsafe and the benign-looking pattern should be safe. Implementation-wise, this means backward() should sync its calling thread's current stream, not its default stream, with the leaf streams.

After https://github.com/pytorch/pytorch/pull/57833, backward() syncs the calling thread's current stream AND the default stream with all leaf streams at the end of the backward pass. The default-stream syncs were retained only for temporary backward compatibility.

This PR finishes the work of https://github.com/pytorch/pytorch/pull/57833 by deleting the default-stream syncs.
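
For reference, the pattern the updated docs note recommends under the new semantics looks like this (a minimal sketch, assuming a CUDA device; the toy model, input, and optimizer are placeholders, not part of this PR):
```python
import torch

model = torch.nn.Linear(8, 8).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(4, 8, device="cuda")

s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    loss = model(x).sum()
    loss.backward()  # leaf streams are synced with s, the current stream here
# With this PR there is no implicit sync with the default stream anymore,
# so wait on s before consuming the grads outside the stream context.
torch.cuda.current_stream().wait_stream(s)
opt.step()
```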

With this PR, graph-capturing an entire backward() call should be possible (see the [test_graph_grad_scaling diffs](https://github.com/pytorch/pytorch/compare/master...mcarilli:streaming_backwards_remove_default_syncs?expand=1#diff-893b1eea27352f336f4cd832919e48d721e4e90186e63400b8596db6b82e7450R3641-R3642)).
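
In outline, the capture flow the updated test exercises is roughly the following (a condensed sketch of the diffed test, not a definitive recipe; torch.cuda._Graph is a private, experimental API here, and the setup tensors/optimizer below are assumed stand-ins for what the test defines earlier):
```python
import torch

# Assumed setup; the real test defines its own weight, static_input, and optimizer.
weight = torch.ones(8, device="cuda", requires_grad=True)
static_input = torch.ones(8, device="cuda")
opt = torch.optim.SGD([weight], lr=0.1)

scaler = torch.cuda.amp.GradScaler(init_scale=4.)
g = torch.cuda._Graph()
s = torch.cuda.Stream()

s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    # warmup iteration outside the graph
    loss = (weight.half() * static_input).sum()
    scaler.scale(loss).backward()
    opt.zero_grad(set_to_none=True)
    # capture a full forward + backward; no hand-rolled grad computation needed
    g.capture_begin()
    loss = (weight.half() * static_input).sum()
    scaler.scale(loss).backward()
    g.capture_end()
torch.cuda.current_stream().wait_stream(s)
```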

** The first paragraph of the linked note has a formatting error, which this PR also fixes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/60421

Reviewed By: VitalyFedyunin, albanD

Differential Revision: D29342234

Pulled By: ngimel

fbshipit-source-id: 98e6be7fdd8550872f0a78f9a66cb8dfe75abf63
diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst
index d19b460..e26ec8b 100644
--- a/docs/source/notes/cuda.rst
+++ b/docs/source/notes/cuda.rst
@@ -201,10 +201,14 @@
 Stream semantics of backward passes
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-A. Each backward CUDA op runs on the same stream that was used for its corresponding forward op.
+Each backward CUDA op runs on the same stream that was used for its corresponding forward op.
+If your forward pass runs independent ops in parallel on different streams,
+this helps the backward pass exploit that same parallelism.
 
-B. The stream semantics of a backward call with respect to surrounding ops are the same
-as for any other call. More concretely, when calling
+The stream semantics of a backward call with respect to surrounding ops are the same
+as for any other call. The backward pass inserts internal syncs to ensure this even when
+backward ops run on multiple streams as described in the previous paragraph.
+More concretely, when calling
 :func:`autograd.backward<torch.autograd.backward>`,
 :func:`autograd.grad<torch.autograd.grad>`, or
 :meth:`tensor.backward<torch.Tensor.backward>`,
@@ -255,11 +259,26 @@
         initial_grad.record_stream(s)
         loss.backward(gradient=initial_grad)
 
-If your forward pass runs some independent ops in parallel on different streams,
-A. helps the backward pass exploit that same parallelism.
+BC note: Using grads on the default stream
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-The backward call inserts internal syncs as needed to ensure B. holds true even if A.
-makes some backward ops run on assorted side streams.
+In prior versions of PyTorch (1.9 and earlier), the autograd engine always synced
+the default stream with all backward ops, so the following pattern::
+
+    with torch.cuda.stream(s):
+        loss.backward()
+    use grads
+
+was safe as long as ``use grads`` happened on the default stream.
+In present PyTorch, that pattern is no longer safe. If ``backward()``
+and ``use grads`` are in different stream contexts, you must sync the streams::
+
+    with torch.cuda.stream(s):
+        loss.backward()
+    torch.cuda.current_stream().wait_stream(s)
+    use grads
+
+even if ``use grads`` is on the default stream.
 
 .. _CUDA stream: https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#streams
 
diff --git a/test/test_cuda.py b/test/test_cuda.py
index f2af7ab..aa9447c 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1763,46 +1763,6 @@
 
     # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
     @skipIfRocm
-    def test_streaming_backwards_multiple_streams_legacy(self):
-        # Tests calling backward() under a side stream then using a grad
-        # on the default stream without syncing. Right now, this pattern is safe,
-        # but only for BC. In a future PR, this pattern will become unsafe,
-        # a sync will be required, and this test will be deleted in favor of
-        # test_streaming_backward_multiple_streams below.
-        class StreamModel(torch.nn.Module):
-            def __init__(self):
-                super(StreamModel, self).__init__()
-                self.event = torch.cuda.Event()
-                self.stream0 = torch.cuda.Stream()
-                self.stream1 = torch.cuda.Stream()
-
-            def forward(self, x):
-                x0 = x.clone()
-                torch._C._cuda_setStream(self.stream0._cdata)
-                y0 = x0 * 2
-                self.event.record(stream=torch.cuda.current_stream())
-
-                torch._C._cuda_setStream(self.stream1._cdata)
-                y1 = x * 3
-                self.stream1.wait_event(self.event)
-                return y0 + y1
-
-        stream = torch.cuda.Stream()
-
-        def accum_hook(grad):
-            self.assertEqual(torch.cuda.current_stream(), stream)
-
-        with torch.cuda.stream(stream):
-            x = torch.randn(5, 5, device='cuda', requires_grad=True)
-            x.register_hook(accum_hook)
-            torch.cuda.current_stream().wait_stream(stream)
-            model = StreamModel().cuda()
-            model(x).sum().backward()
-
-        self.assertEqual(x.grad, torch.ones_like(x) * 5)
-
-    # Skip the test for ROCm as per https://github.com/pytorch/pytorch/issues/53190
-    @skipIfRocm
     def test_streaming_backwards_multiple_streams(self):
         MultiplyInStream = self._make_multiply_in_stream()
 
@@ -3668,6 +3628,8 @@
                      TEST_WITH_ROCM or
                      int(torch.version.cuda.split(".")[0]) < 11, "CUDA >= 11.0 required for graphs")
     def test_graph_grad_scaling(self):
+        torch.cuda.empty_cache()
+
         scaler = torch.cuda.amp.GradScaler(init_scale=4.)
         g = torch.cuda._Graph()
         s = torch.cuda.Stream()
@@ -3680,15 +3642,13 @@
         s.wait_stream(torch.cuda.current_stream())
         with torch.cuda.stream(s):
             # warmup
-            weight.grad = (scaler.scale(static_grad) * static_input).half().float()
+            loss = (weight.half() * static_input).sum()
+            scaler.scale(loss).backward()
+            opt.zero_grad(set_to_none=True)
             # capture
             g.capture_begin()
-            weight.grad = (scaler.scale(static_grad) * static_input).half().float()
-            # The above simulates a rudimentary backward pass.
-            # TODO: Once full-backward() capture is enabled (see https://github.com/pytorch/pytorch/pull/54227)
-            # change to
-            # loss = (w.half() * static_input).sum()
-            # scaler.scale(loss).backward()
+            loss = (weight.half() * static_input).sum()
+            scaler.scale(loss).backward()
             g.capture_end()
         torch.cuda.current_stream().wait_stream(s)
 
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 6ed2203..f881d3f 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -164,18 +164,13 @@
 //
 // Internally, backward() runs ops (including leaf nodes) on side threads.
 // And streams are thread local. So GraphTask achieves the above semantics by
-//  1. remembering the current and default streams on all active CUDA devices
+//  1. remembering the current streams on all active CUDA devices
 //     in the user-facing thread (aka, the thread that called execute() to
 //     launch the GraphTask)
 //  2. remembering the "leaf streams" (streams each backward leaf node ran on)
 //  3. during exec_post_processing, for each leaf stream, sync the remembered
-//     current and default streams (on the leaf stream's device) with that
+//     current streams (on the leaf stream's device) with that
 //     leaf stream.
-//
-// Syncing default streams (as well as current streams) with leaf streams is
-// done for temporary BC, and is more conservative than the usage guidance
-// (https://pytorch.org/docs/stable/notes/cuda.html) requires.
-// TODO: change 1, 2, 3 to sync only current streams with leaf streams.
 
 int NodeTask::getReentrantDepth() const {
   std::shared_ptr<GraphTask> graph_task = base_.lock();
@@ -591,18 +586,6 @@
       cb_lock.lock();
     }
   }
-
-  // For temporary BC, syncs default streams with caller_current_streams so callback results are also
-  // usable on user-facing default streams after backward()
-  for (const auto& caller_current_stream : caller_current_streams_filtered) {
-    const auto caller_default_stream = *caller_default_streams_[caller_current_stream.device_index()];
-
-    if (caller_current_stream != caller_default_stream) {
-      auto event = c10::Event{c10::DeviceType::CUDA};
-      event.record(caller_current_stream);
-      caller_default_stream.wait(event);
-    }
-  }
 }
 
 void GraphTask::set_exception_without_signal(const std::shared_ptr<Node>& fn) {
@@ -959,7 +942,7 @@
   }
 
   if (will_use_cuda) {
-    // Collects current and default streams for devices where this process has a context,
+    // Collects current streams for devices where this process has a context,
     // so GraphTask::exec_post_processing can sync them with leaf_streams.
     task.stash_current_streams();
   }
@@ -1254,13 +1237,12 @@
   thread_pool_shared_->work_.notify_one();
 }
 
-// Remembers current and default streams on all devices where a context has been created.
+// Remembers current streams on all devices where a context has been created.
 // Only called if Engine::execute detects at least one node runs on a cuda stream.
 void GraphTask::stash_current_streams() {
   const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
   auto num_gpus = guard.deviceCount();
   caller_current_streams_.resize(num_gpus);
-  caller_default_streams_.resize(num_gpus);
   if (num_gpus > 0) {
     for (c10::DeviceIndex idx = 0; idx < num_gpus;  idx++) {
 #ifdef __HIP_PLATFORM_HCC__
@@ -1272,10 +1254,8 @@
       if (at::detail::getCUDAHooks().hasPrimaryContext(idx)) {
 #endif
         caller_current_streams_[idx] = guard.getStream({c10::DeviceType::CUDA, idx});
-        caller_default_streams_[idx] = guard.getDefaultStream({c10::DeviceType::CUDA, idx});
       } else {
         caller_current_streams_[idx] = c10::nullopt;
-        caller_default_streams_[idx] = c10::nullopt;
       }
     }
   }
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index 7c810cf..668b156 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -112,12 +112,11 @@
 
   std::unordered_set<c10::Stream> leaf_streams;
 
-  // Per-device current and default streams of the execute() that called this GraphTask.
+  // Per-device current streams of the execute() that called this GraphTask.
   // These will be synced with leaf_streams in exec_post_processing.
   std::vector<c10::optional<c10::Stream>> caller_current_streams_;
-  std::vector<c10::optional<c10::Stream>> caller_default_streams_;
 
-  // Collects caller_current_streams_ and caller_default_streams_
+  // Collects caller_current_streams_
   void stash_current_streams();
 
   void init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr);
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
index 50c2445..76f2eae 100644
--- a/torch/csrc/distributed/autograd/engine/dist_engine.cpp
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -263,7 +263,7 @@
   }
 
   if (will_use_cuda) {
-    // Collects current and default streams for devices where this process has a context,
+    // Collects current streams for devices where this process has a context,
     // so graphTask::exec_post_processing can sync them with leaf_streams.
     graphTask->stash_current_streams();
   }