Add default_stream() and enhance current_stream() (#16200)
Summary:
Closes #16156
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16200
Differential Revision: D13747455
Pulled By: mrshenli
fbshipit-source-id: 00c0d5f341c3ac7a757bdb4631a17e11fbc6d3ec
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 79a795b..bbd5536 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1422,6 +1422,67 @@
def test_cuda_synchronize(self):
torch.cuda.synchronize()
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+ @skipIfRocm
+ def test_current_stream(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+
+ s0 = torch.cuda.current_stream()
+ s1 = torch.cuda.current_stream(device=1)
+ s2 = torch.cuda.current_stream(device=0)
+
+ self.assertEqual(d0, s0.device)
+ self.assertEqual(d1, s1.device)
+ self.assertEqual(d0, s2.device)
+ self.assertEqual(s0, s2)
+
+ with torch.cuda.device(d1):
+ s0 = torch.cuda.current_stream()
+ s1 = torch.cuda.current_stream(1)
+ s2 = torch.cuda.current_stream(d0)
+
+ self.assertEqual(d1, s0.device)
+ self.assertEqual(d1, s1.device)
+ self.assertEqual(d0, s2.device)
+ self.assertEqual(s0, s1)
+
+ with self.assertRaisesRegex(ValueError,
+ "Expected a cuda device, but got: cpu"):
+ torch.cuda.current_stream(torch.device('cpu'))
+
+ @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+ @skipIfRocm
+ def test_default_stream(self):
+ d0 = torch.device('cuda:0')
+ d1 = torch.device('cuda:1')
+
+ with torch.cuda.device(d0):
+ s0 = torch.cuda.default_stream()
+
+ with torch.cuda.device(d1):
+ s1 = torch.cuda.default_stream()
+
+ s2 = torch.cuda.default_stream(device=0)
+ s3 = torch.cuda.default_stream(d1)
+
+ self.assertEqual(d0, s0.device)
+ self.assertEqual(d1, s1.device)
+ self.assertEqual(d0, s2.device)
+ self.assertEqual(d1, s3.device)
+ self.assertEqual(s0, s2)
+ self.assertEqual(s1, s3)
+
+ with torch.cuda.device(d0):
+ self.assertEqual(torch.cuda.current_stream(), s0)
+
+ with torch.cuda.device(d1):
+ self.assertEqual(torch.cuda.current_stream(), s1)
+
+ with self.assertRaisesRegex(ValueError,
+ "Expected a cuda device, but got: cpu"):
+ torch.cuda.default_stream(torch.device('cpu'))
+
@skipIfRocm
def test_streams(self):
default_stream = torch.cuda.current_stream()
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 4408612..81ab835 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -66,11 +66,25 @@
END_HANDLE_TH_ERRORS
}
-
-PyObject * THCPModule_getCurrentStream_wrap(PyObject *self)
-{
+PyObject * THCPModule_getCurrentStream_wrap(
+ PyObject * /* unused */, PyObject *device_index) {
HANDLE_TH_ERRORS
- return PyLong_FromUnsignedLongLong(at::cuda::getCurrentCUDAStream().pack());
+ THPUtils_assert(
+ THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
+ int64_t device = THPUtils_unpackLong(device_index);
+ return PyLong_FromUnsignedLongLong(
+ at::cuda::getCurrentCUDAStream(device).pack());
+ END_HANDLE_TH_ERRORS
+}
+
+PyObject * THCPModule_getDefaultStream_wrap(
+ PyObject * /* unused */, PyObject *device_index) {
+ HANDLE_TH_ERRORS
+ THPUtils_assert(
+ THPUtils_checkLong(device_index), "invalid argument to getDefaultStream");
+ int64_t device = THPUtils_unpackLong(device_index);
+ return PyLong_FromUnsignedLongLong(
+ at::cuda::getDefaultCUDAStream(device).pack());
END_HANDLE_TH_ERRORS
}
@@ -412,7 +426,10 @@
{"_cuda_setDevice", (PyCFunction)THCPModule_setDevice_wrap, METH_O, nullptr},
{"_cuda_getDevice", (PyCFunction)THCPModule_getDevice_wrap, METH_NOARGS, nullptr},
{"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
- {"_cuda_getCurrentStream", (PyCFunction)THCPModule_getCurrentStream_wrap, METH_NOARGS, nullptr},
+ {"_cuda_getCurrentStream",
+ (PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
+ {"_cuda_getDefaultStream",
+ (PyCFunction)THCPModule_getDefaultStream_wrap, METH_O, nullptr},
{"_cuda_getCurrentBlasHandle", (PyCFunction)THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr},
{"_cuda_setStream", (PyCFunction)THCPModule_setStream_wrap, METH_O, nullptr},
{"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, nullptr},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 8015b20..4f4519f 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -348,10 +348,32 @@
return torch._C._cuda_synchronize()
-def current_stream():
- r"""Returns a currently selected :class:`Stream`."""
+def current_stream(device=None):
+ r"""Returns the currently selected :class:`Stream` for a given device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ the currently selected :class:`Stream` for the current device, given
+ by :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ (default).
+ """
_lazy_init()
- return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream())
+ return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream(
+ _get_device_index(device, optional=True)))
+
+
+def default_stream(device=None):
+ r"""Returns the default :class:`Stream` for a given device.
+
+ Arguments:
+ device (torch.device or int, optional): selected device. Returns
+ the default :class:`Stream` for the current device, given by
+ :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+ (default).
+ """
+ _lazy_init()
+ return torch.cuda.Stream(_cdata=torch._C._cuda_getDefaultStream(
+ _get_device_index(device, optional=True)))
def current_blas_handle():