enable warnings on cuda synchronization (#62092)
Summary:
This adds a `torch.cuda.set_sync_debug_mode()` function that warns or errors when a synchronizing operation is performed. We could wrap it in a context manager for ease of use, but that would be misleading, because it sets global, not thread-local, state. Since it's intended for debugging, maybe that's acceptable.
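For illustration only, such a guard (mirroring the `CudaSyncGuard` test helper added to `torch/testing/_internal/common_utils.py` in this PR; the name `_SyncDebugGuard` is made up here) could look roughly like this:
```
import torch

# Sketch only: a guard around the *global* sync debug mode. The state it
# toggles is process-wide, not thread-local, which is why it stays a test
# helper and is not part of the public API.
class _SyncDebugGuard:
    def __init__(self, debug_mode):
        self.debug_mode = debug_mode

    def __enter__(self):
        self.restore = torch.cuda.get_sync_debug_mode()
        torch.cuda.set_sync_debug_mode(self.debug_mode)

    def __exit__(self, exc_type, exc_value, traceback):
        torch.cuda.set_sync_debug_mode(self.restore)

# with _SyncDebugGuard("warn"):
#     torch.randn(10, device="cuda").nonzero()  # warns: nonzero synchronizes
```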
Like all `torch.cuda.*` functions, it goes through CPython, not pybind, so the argument is converted to a long before being passed to the c10 function. The Python argument could be made a Python enum class, but without pybind it would still have to go through the long conversion.
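A hypothetical enum wrapper on the Python side (not part of this PR) would still hand a plain integer across the CPython boundary, e.g.:
```
import enum
import torch

# Hypothetical sketch: the API added in this PR accepts an int (or str);
# an enum would only be syntactic sugar on top of the same long conversion.
class SyncDebugMode(enum.IntEnum):
    DEFAULT = 0  # don't warn or error
    WARN = 1     # warn on synchronizing operations
    ERROR = 2    # error out on synchronizing operations

torch.cuda.set_sync_debug_mode(int(SyncDebugMode.WARN))
```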
For a test script
```
import torch
torch.cuda.set_sync_debug_mode(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
    print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
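With level 2 (or "error"), the same synchronizing calls raise instead of warning, roughly:
```
import torch

torch.cuda.set_sync_debug_mode("error")
x = torch.randn(10, device="cuda")
x.nonzero()  # RuntimeError: called a synchronizing CUDA operation
```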
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp
index 1f781b3..255d798 100644
--- a/c10/cuda/CUDAFunctions.cpp
+++ b/c10/cuda/CUDAFunctions.cpp
@@ -138,5 +138,15 @@
C10_CUDA_CHECK(cudaDeviceSynchronize());
}
+// this function has to be called by callers performing cuda synchronizing
+// operations to raise the proper error or warning
+void warn_or_error_on_sync() {
+ if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) {
+ TORCH_CHECK(false, "called a synchronizing CUDA operation");
+ } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) {
+ TORCH_WARN("called a synchronizing CUDA operation");
+ }
+}
+
} // namespace cuda
} // namespace c10
diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h
index 1464999..7a55e9d 100644
--- a/c10/cuda/CUDAFunctions.h
+++ b/c10/cuda/CUDAFunctions.h
@@ -14,7 +14,6 @@
#include <hip/hip_version.h>
#endif
#include <cuda_runtime_api.h>
-
namespace c10 {
namespace cuda {
@@ -35,6 +34,32 @@
C10_CUDA_API void device_synchronize();
+C10_CUDA_API void warn_or_error_on_sync();
+
+enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
+
+// this is a holder for c10 global state (similar to at GlobalContext)
+// currently it's used to store cuda synchronization warning state,
+// but can be expanded to hold other related global state, e.g. to
+// record stream usage
+class WarningState {
+ public:
+ void set_sync_debug_mode(SyncDebugMode l) {
+ sync_debug_mode = l;
+ }
+
+ SyncDebugMode get_sync_debug_mode() {
+ return sync_debug_mode;
+ }
+
+ private:
+ SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
+};
+
+C10_CUDA_API __inline__ WarningState& warning_state() {
+ static WarningState warning_state_;
+ return warning_state_;
+}
// the subsequent functions are defined in the header because for performance
// reasons we want them to be inline
C10_CUDA_API void __inline__ memcpy_and_sync(
@@ -43,6 +68,10 @@
int64_t nbytes,
cudaMemcpyKind kind,
cudaStream_t stream) {
+ if (C10_UNLIKELY(
+ warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
+ warn_or_error_on_sync();
+ }
#if defined(HIP_VERSION) && (HIP_VERSION >= 301)
C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream));
#else
@@ -52,6 +81,10 @@
}
C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) {
+ if (C10_UNLIKELY(
+ warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) {
+ warn_or_error_on_sync();
+ }
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
}
diff --git a/docs/source/cuda.rst b/docs/source/cuda.rst
index 4f8fb54..d4783c8 100644
--- a/docs/source/cuda.rst
+++ b/docs/source/cuda.rst
@@ -21,12 +21,14 @@
get_device_name
get_device_properties
get_gencode_flags
+ get_sync_debug_mode
init
ipc_collect
is_available
is_initialized
set_device
set_stream
+ set_sync_debug_mode
stream
synchronize
diff --git a/test/test_torch.py b/test/test_torch.py
index b4eaa05..57d5037 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -33,7 +33,7 @@
do_test_dtypes, IS_SANDCASTLE, IS_FBCODE, IS_REMOTE_GPU, load_tests, slowTest,
skipCUDAMemoryLeakCheckIf, BytesIOContext, noarchTest,
skipIfRocm, skipIfNoSciPy, TemporaryFileName, TemporaryDirectoryName,
- wrapDeterministicFlagAPITest, DeterministicGuard, make_tensor)
+ wrapDeterministicFlagAPITest, DeterministicGuard, CudaSyncGuard, make_tensor)
from multiprocessing.reduction import ForkingPickler
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
@@ -4254,6 +4254,65 @@
self.assertEqual(res, expected, atol=0, rtol=0)
+ @onlyCUDA
+ def test_sync_warning(self, device):
+
+ def _sync_raises_helper(f, level):
+ with CudaSyncGuard(level):
+ if level == 1:
+ with self.assertWarnsRegex(UserWarning, "called a synchronizing "):
+ f()
+ elif level == 2:
+ with self.assertRaisesRegex(RuntimeError, "called a synchronizing "):
+ f()
+
+ def _no_sync_helper(f, level):
+ with CudaSyncGuard(level):
+ f()
+
+ def _ind_put_fn(x, ind, val):
+ x[ind] = val
+ return x
+
+ def _ind_get_fn(x, ind):
+ return x[ind]
+
+ def _cond_fn(x):
+ if x: # taking boolean value of a tensor synchronizes
+ return x
+ else:
+ return 2 * x
+
+ # prepare inputs for subsequent ops
+ size = 4
+ x = torch.rand(size, device=device)
+ y = torch.rand((), device=device)
+ ind = torch.randint(size, (3,), device=device)
+ ind_cpu = ind.cpu()
+ repeats = torch.full((1,), 2, device=device)
+ mask = torch.randint(2, (size,), device=device, dtype=bool)
+ expect_no_sync = (lambda: _ind_put_fn(x, mask, 1.),
+ lambda: _ind_put_fn(x, ind, y),
+ lambda: _ind_get_fn(x, ind),
+ lambda: torch.nn.functional.one_hot(ind, num_classes=size),
+ lambda: torch.randperm(20000, device=device),
+ lambda: torch.repeat_interleave(x, 2, output_size=2 * size),
+ lambda: torch.repeat_interleave(x, repeats, output_size=2 * size))
+ expect_sync = (lambda: _ind_put_fn(x, mask, y),
+ lambda: _ind_put_fn(x, ind_cpu, y),
+ lambda: _ind_get_fn(x, mask),
+ lambda: _ind_get_fn(x, ind_cpu),
+ lambda: x.nonzero(),
+ lambda: _cond_fn(y),
+ lambda: torch.nn.functional.one_hot(ind),
+ lambda: torch.repeat_interleave(x, 2),
+ lambda: torch.repeat_interleave(x, repeats))
+ for f, level in product(expect_no_sync, (1, 2)):
+ _no_sync_helper(f, level)
+ for f, level in product(expect_sync, (1, 2)):
+ _sync_raises_helper(f, level)
+
+
@dtypes(*torch.testing.get_all_fp_dtypes())
def test_log_normal(self, device, dtype):
a = torch.tensor([10], dtype=dtype, device=device).log_normal_()
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 8cc5f8f..c4f07d1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -777,6 +777,8 @@
def _cuda_setDevice(device: _int) -> None: ...
def _cuda_getDevice() -> _int: ...
def _cuda_getDeviceCount() -> _int: ...
+def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ...
+def _cuda_get_sync_debug_mode() -> _int: ...
def _cuda_sleep(cycles: _int) -> None: ...
def _cuda_synchronize() -> None: ...
def _cuda_ipc_collect() -> None: ...
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 333044f..ce828e9 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -439,6 +439,38 @@
END_HANDLE_TH_ERRORS
}
+PyObject * THCPModule_cudaSetSyncDebugMode(PyObject * _unused, PyObject * arg){
+ HANDLE_TH_ERRORS
+ TORCH_WARN_ONCE("Synchronization debug mode is a prototype feature and does not yet detect all " \
+ "synchronizing operations");
+ THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to set_sync_debug_mode");
+ int64_t debug_mode = THPUtils_unpackLong(arg);
+ TORCH_CHECK(debug_mode >=0 && debug_mode <=2, "invalid value of debug_mode, expected one of 0,1,2");
+ c10::cuda::SyncDebugMode l;
+ switch (debug_mode) {
+ case 0: l = c10::cuda::SyncDebugMode::L_DISABLED; break;
+ case 1: l = c10::cuda::SyncDebugMode::L_WARN; break;
+ case 2: l = c10::cuda::SyncDebugMode::L_ERROR; break;
+ default: l = c10::cuda::SyncDebugMode::L_DISABLED; break; // can't happen
+ }
+ c10::cuda::warning_state().set_sync_debug_mode(l);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject * THCPModule_cudaGetSyncDebugMode(PyObject *self, PyObject *noargs){
+ HANDLE_TH_ERRORS
+ auto debug_mode = c10::cuda::warning_state().get_sync_debug_mode();
+ switch (debug_mode){
+ case c10::cuda::SyncDebugMode::L_DISABLED: return THPUtils_packInt32(0);
+ case c10::cuda::SyncDebugMode::L_WARN: return THPUtils_packInt32(1);
+ case c10::cuda::SyncDebugMode::L_ERROR: return THPUtils_packInt32(2);
+ default: return THPUtils_packInt32(-1); // can't happen
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+
////////////////////////////////////////////////////////////////////////////////
// Cuda module initialization
////////////////////////////////////////////////////////////////////////////////
@@ -575,6 +607,8 @@
{"_cuda_sleep", THCPModule_cudaSleep, METH_O, nullptr},
{"_cuda_lock_mutex", THCPModule_cudaLockMutex, METH_NOARGS, nullptr},
{"_cuda_unlock_mutex", THCPModule_cudaUnlockMutex, METH_NOARGS, nullptr},
+ {"_cuda_set_sync_debug_mode", THCPModule_cudaSetSyncDebugMode, METH_O, nullptr},
+ {"_cuda_get_sync_debug_mode", THCPModule_cudaGetSyncDebugMode, METH_NOARGS, nullptr},
#ifdef USE_NCCL
{"_nccl_version", THCPModule_nccl_version, METH_NOARGS, nullptr},
{"_nccl_unique_id", THCPModule_nccl_unique_id, METH_NOARGS, nullptr},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index e5bb7dd..8acdce6 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -492,6 +492,37 @@
_lazy_init()
return torch._C._cuda_getCurrentBlasHandle()
+def set_sync_debug_mode(debug_mode: Union[int, str]) -> None:
+ r"""Sets the debug mode for cuda synchronizing operations.
+
+ Args:
+        debug_mode(str or int): if "default" or 0, don't error or warn on synchronizing operations,
+            if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out on synchronizing operations.
+
+ Warning:
+ This is an experimental feature, and not all synchronizing operations will trigger warning or error. In
+ particular, operations in torch.distributed and torch.sparse namespaces are not covered yet.
+ """
+
+ _lazy_init()
+ if isinstance(debug_mode, str):
+ if debug_mode == "default":
+ debug_mode = 0
+ elif debug_mode == "warn":
+ debug_mode = 1
+ elif debug_mode == "error":
+ debug_mode = 2
+ else:
+ raise RuntimeError("invalid value of debug_mode, expected one of `default`, `warn`, `error`")
+
+ torch._C._cuda_set_sync_debug_mode(debug_mode)
+
+def get_sync_debug_mode() -> int:
+ r"""Returns current value of debug mode for cuda synchronizing operations."""
+
+ _lazy_init()
+ return torch._C._cuda_get_sync_debug_mode()
+
from .memory import * # noqa: F403
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 05006df..ee3c455 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -500,6 +500,21 @@
def __exit__(self, exception_type, exception_value, traceback):
torch.use_deterministic_algorithms(self.deterministic_restore)
+# Context manager for setting the cuda sync debug mode and resetting it
+# to its original value
+# we are not exposing it to the core because sync debug mode is
+# global and thus not thread safe
+class CudaSyncGuard:
+ def __init__(self, sync_debug_mode):
+ self.mode = sync_debug_mode
+
+ def __enter__(self):
+ self.debug_mode_restore = torch.cuda.get_sync_debug_mode()
+ torch.cuda.set_sync_debug_mode(self.mode)
+
+ def __exit__(self, exception_type, exception_value, traceback):
+ torch.cuda.set_sync_debug_mode(self.debug_mode_restore)
+
# This decorator can be used for API tests that call
# torch.use_deterministic_algorithms(). When the test is finished, it will
# restore the previous deterministic flag setting.