Enable warnings on CUDA synchronization (#62092)

Summary:
This adds a `torch.cuda.set_sync_debug_mode()` function that warns or errors when a synchronizing operation is performed. We could wrap it in a context manager for ease of use, but that would be misleading, because it sets global, not thread-local, state. Since it's intended for debugging, maybe that's acceptable.
Like all `torch.cuda.*` functions, it goes through CPython, not pybind, so the argument is converted to a long before being passed to the c10 function. I'll make the Python argument a Python enum class, but without pybind it will still have to go through the long conversion.
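
For reference, wrapping the setter in a guard takes only a few lines; the sketch below (with a hypothetical `sync_debug_guard` name) mirrors the `CudaSyncGuard` test helper added to `common_utils.py` in this PR and inherits the same caveat: the state it toggles is process-global, not thread-local.
```
import torch

class sync_debug_guard:
    """Hypothetical context manager around the global sync debug mode."""
    def __init__(self, debug_mode):
        self.debug_mode = debug_mode

    def __enter__(self):
        # save the current global mode and switch to the requested one
        self.prev = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode(self.debug_mode)

    def __exit__(self, exc_type, exc_value, traceback):
        # restore the previous mode; this is process-global, not thread-local
        torch.cuda.set_sync_debug_mode(self.prev)
```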

For a test script
```
import torch
torch.cuda.set_sync_debug_mode(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")

if y:
    print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at  ../c10/cuda/CUDAFunctions.cpp:145.)
  x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at  ../c10/cuda/CUDAFunctions.cpp:145.)
  if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at  ../c10/cuda/CUDAFunctions.cpp:145.)
  torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at  ../c10/cuda/CUDAFunctions.cpp:145.)
  x[mask] = val
```
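
Note that `x[mask] = 1.` produces no warning while `x[mask] = val` does, matching the `expect_no_sync` / `expect_sync` split in the new test. For completeness, the mode can also be set by name and read back as an int through the new Python bindings in `torch/cuda/__init__.py`; a small usage sketch:
```
import torch

torch.cuda.set_sync_debug_mode("warn")     # equivalent to passing 1
assert torch.cuda.get_sync_debug_mode() == 1

torch.cuda.set_sync_debug_mode("error")    # synchronizing ops now raise RuntimeError
torch.cuda.set_sync_debug_mode("default")  # back to the silent default (0)
```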

Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092

Reviewed By: mruberry

Differential Revision: D29968792

Pulled By: ngimel

fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp
index 1f781b3..255d798 100644
--- a/c10/cuda/CUDAFunctions.cpp
+++ b/c10/cuda/CUDAFunctions.cpp
@@ -138,5 +138,15 @@
   C10_CUDA_CHECK(cudaDeviceSynchronize());
 }
 
+// this function has to be called by callers performing cuda synchronizing
+// operations to raise the proper error or warning
+void warn_or_error_on_sync() {
+  if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) {
+    TORCH_CHECK(false, "called a synchronizing CUDA operation");
+  } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) {
+    TORCH_WARN("called a synchronizing CUDA operation");
+  }
+}
+
 } // namespace cuda
 } // namespace c10
diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h
index 1464999..7a55e9d 100644
--- a/c10/cuda/CUDAFunctions.h
+++ b/c10/cuda/CUDAFunctions.h
@@ -14,7 +14,6 @@
 #include <hip/hip_version.h>
 #endif
 #include <cuda_runtime_api.h>
-
 namespace c10 {
 namespace cuda {
 
@@ -35,6 +34,32 @@
 
 C10_CUDA_API void device_synchronize();
 
+C10_CUDA_API void warn_or_error_on_sync();
+
+enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
+
+// this is a holder for c10 global state (similar to ATen's global context)
+// currently it's used to store cuda synchronization warning state,
+// but can be expanded to hold other related global state, e.g. to
+// record stream usage
+class WarningState {
+ public:
+  void set_sync_debug_mode(SyncDebugMode l) {
+    sync_debug_mode = l;
+  }
+
+  SyncDebugMode get_sync_debug_mode() {
+    return sync_debug_mode;
+  }
+
+ private:
+  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
+};
+
+C10_CUDA_API __inline__ WarningState& warning_state() {
+  static WarningState warning_state_;
+  return warning_state_;
+}
 // the subsequent functions are defined in the header because for performance
 // reasons we want them to be inline
 C10_CUDA_API void __inline__ memcpy_and_sync(
@@ -43,6 +68,10 @@
     int64_t nbytes,
     cudaMemcpyKind kind,
     cudaStream_t stream) {
+  if (C10_UNLIKELY(
+          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
+    warn_or_error_on_sync();
+  }
 #if defined(HIP_VERSION) && (HIP_VERSION >= 301)
   C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
 #else
@@ -52,6 +81,10 @@
 }
 
 C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
+  if (C10_UNLIKELY(
+          warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
+    warn_or_error_on_sync();
+  }
   C10_CUDA_CHECK(cudaStreamSynchronize(stream));
 }
 
diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst
index 4f8fb54..d4783c8 100644
--- a/docs/source/cuda.rst
+++ b/docs/source/cuda.rst
@@ -21,12 +21,14 @@
     get_device_name
     get_device_properties
     get_gencode_flags
+    get_sync_debug_mode
     init
     ipc_collect
     is_available
     is_initialized
     set_device
     set_stream
+    set_sync_debug_mode
     stream
     synchronize
 
diff --git a/test/test_torch.py b/test/test_torch.py
index b4eaa05..57d5037 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -33,7 +33,7 @@
     do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
     skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest,
     skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
-    wrapDeterministicFlagAPITest, DeterministicGuard, make_tensor)
+    wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, make_tensor)
 from multiprocessing.reduction import ForkingPickler
 from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
@@ -4254,6 +4254,65 @@
 
             self.assertEqual(res, expected, atol=0, rtol=0)
 
+    @onlyCUDA
+    def test_sync_warning(self, device):
+
+        def _sync_raises_helper(f, level):
+            with CudaSyncGuard(level):
+                if level == 1:
+                    with self.assertWarnsRegex(UserWarning, "called a synchronizing "):
+                        f()
+                elif level == 2:
+                    with self.assertRaisesRegex(RuntimeError, "called a synchronizing "):
+                        f()
+
+        def _no_sync_helper(f, level):
+            with CudaSyncGuard(level):
+                f()
+
+        def _ind_put_fn(x, ind, val):
+            x[ind] = val
+            return x
+
+        def _ind_get_fn(x, ind):
+            return x[ind]
+
+        def _cond_fn(x):
+            if x:  # taking boolean value of a tensor synchronizes
+                return x
+            else:
+                return 2 * x
+
+        # prepare inputs for subsequent ops
+        size = 4
+        x = torch.rand(size, device=device)
+        y = torch.rand((), device=device)
+        ind = torch.randint(size, (3,), device=device)
+        ind_cpu = ind.cpu()
+        repeats = torch.full((1,), 2, device=device)
+        mask = torch.randint(2, (size,), device=device, dtype=bool)
+        expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
+                          lambda: _ind_put_fn(x, ind, y),
+                          lambda: _ind_get_fn(x, ind),
+                          lambda: torch.nn.functional.one_hot(ind, num_classes=size),
+                          lambda: torch.randperm(20000, device=device),
+                          lambda: torch.repeat_interleave(x, 2, output_size=2 * size),
+                          lambda: torch.repeat_interleave(x, repeats, output_size=2 * size))
+        expect_sync = (lambda: _ind_put_fn(x, mask, y),
+                       lambda: _ind_put_fn(x, ind_cpu, y),
+                       lambda: _ind_get_fn(x, mask),
+                       lambda: _ind_get_fn(x, ind_cpu),
+                       lambda: x.nonzero(),
+                       lambda: _cond_fn(y),
+                       lambda: torch.nn.functional.one_hot(ind),
+                       lambda: torch.repeat_interleave(x, 2),
+                       lambda: torch.repeat_interleave(x, repeats))
+        for f, level in product(expect_no_sync, (1, 2)):
+            _no_sync_helper(f, level)
+        for f, level in product(expect_sync, (1, 2)):
+            _sync_raises_helper(f, level)
+
+
     @dtypes(*torch.testing.get_all_fp_dtypes())
     def test_log_normal(self, device, dtype):
         a = torch.tensor([10], dtype=dtype, device=device).log_normal_()
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 8cc5f8f..c4f07d1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -777,6 +777,8 @@
 def _cuda_setDevice(device: _int) -> None: ...
 def _cuda_getDevice() -> _int: ...
 def _cuda_getDeviceCount() -> _int: ...
+def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
+def _cuda_get_sync_debug_mode() -> _int: ...
 def _cuda_sleep(cycles: _int) -> None: ...
 def _cuda_synchronize() -> None: ...
 def _cuda_ipc_collect() -> None: ...
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 333044f..ce828e9 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -439,6 +439,38 @@
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_cudaSetSyncDebugMode(PyObject * _unused, PyObject * arg){
+  HANDLE_TH_ERRORS
+  TORCH_WARN_ONCE("Synchronization debug mode is a prototype feature and does not yet detect all " \
+  "synchronizing operations");
+  THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode");
+  int64_t debug_mode = THPUtils_unpackLong(arg);
+  TORCH_CHECK(debug_mode >=0 && debug_mode <=2, "invalid value of debug_mode, expected one of 0,1,2");
+  c10::cuda::SyncDebugMode l;
+  switch (debug_mode) {
+    case 0: l = c10::cuda::SyncDebugMode::L_DISABLED; break;
+    case 1: l = c10::cuda::SyncDebugMode::L_WARN; break;
+    case 2: l = c10::cuda::SyncDebugMode::L_ERROR; break;
+    default: l = c10::cuda::SyncDebugMode::L_DISABLED; break; // can't happen
+  }
+  c10::cuda::warning_state().set_sync_debug_mode(l);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject * THCPModule_cudaGetSyncDebugMode(PyObject *self, PyObject *noargs){
+  HANDLE_TH_ERRORS
+  auto debug_mode = c10::cuda::warning_state().get_sync_debug_mode();
+  switch (debug_mode){
+    case c10::cuda::SyncDebugMode::L_DISABLED: return THPUtils_packInt32(0);
+    case c10::cuda::SyncDebugMode::L_WARN: return THPUtils_packInt32(1);
+    case c10::cuda::SyncDebugMode::L_ERROR: return THPUtils_packInt32(2);
+    default: return THPUtils_packInt32(-1); // can't happen
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+
 ////////////////////////////////////////////////////////////////////////////////
 // Cuda module initialization
 ////////////////////////////////////////////////////////////////////////////////
@@ -575,6 +607,8 @@
   {"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},
   {"_cuda_lock_mutex",   THCPModule_cudaLockMutex,   METH_NOARGS,  nullptr},
   {"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS,  nullptr},
+  {"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, nullptr},
+  {"_cuda_get_sync_debug_mode", THCPModule_cudaGetSyncDebugMode, METH_NOARGS, nullptr},
 #ifdef USE_NCCL
   {"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
   {"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index e5bb7dd..8acdce6 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -492,6 +492,37 @@
     _lazy_init()
     return torch._C._cuda_getCurrentBlasHandle()
 
+def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
+    r"""Sets the debug mode for cuda synchronizing operations.
+
+    Args:
+        debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations;
+            if "warn" or 1, warn on synchronizing operations; if "error" or 2, error out on synchronizing operations.
+
+    Warning:
+        This is an experimental feature, and not all synchronizing operations will trigger a warning or an error. In
+        particular, operations in the torch.distributed and torch.sparse namespaces are not covered yet.
+    """
+
+    _lazy_init()
+    if isinstance(debug_mode, str):
+        if debug_mode == "default":
+            debug_mode = 0
+        elif debug_mode == "warn":
+            debug_mode = 1
+        elif debug_mode == "error":
+            debug_mode = 2
+        else:
+            raise RuntimeError("invalid value of debug_mode, expected one of `default`, `warn`, `error`")
+
+    torch._C._cuda_set_sync_debug_mode(debug_mode)
+
+def get_sync_debug_mode() -> int:
+    r"""Returns current value of debug mode for cuda synchronizing operations."""
+
+    _lazy_init()
+    return torch._C._cuda_get_sync_debug_mode()
+
 
 from .memory import *  # noqa: F403
 
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 05006df..ee3c455 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -500,6 +500,21 @@
     def __exit__(self, exception_type, exception_value, traceback):
         torch.use_deterministic_algorithms(self.deterministic_restore)
 
+# Context manager for setting the cuda sync debug mode and resetting it
+# to its original value.
+# We are not exposing it to the core because sync debug mode is
+# global and thus not thread safe.
+class CudaSyncGuard:
+    def __init__(self, sync_debug_mode):
+        self.mode = sync_debug_mode
+
+    def __enter__(self):
+        self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
+        torch.cuda.set_sync_debug_mode(self.mode)
+
+    def __exit__(self, exception_type, exception_value, traceback):
+        torch.cuda.set_sync_debug_mode(self.debug_mode_restore)
+
 # This decorator can be used for API tests that call
 # torch.use_deterministic_algorithms().  When the test is finished, it will
 # restore the previous deterministic flag setting.