Add default_stream() and enhance current_stream() (#16200)

Summary:
Closes #16156
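
With this change, torch.cuda.current_stream() accepts an optional device
argument, and the new torch.cuda.default_stream() exposes a device's default
(null) stream. A minimal usage sketch (assumes a machine with at least two
CUDA devices; the device indices are illustrative):

    import torch

    # Query the stream that work on device 1 would currently be enqueued on,
    # without switching away from the current device.
    s1 = torch.cuda.current_stream(device=1)

    # Fetch the default (null) stream for device 0 directly.
    d0_default = torch.cuda.default_stream(torch.device('cuda:0'))

    # With no argument, both functions fall back to the current device.
    s = torch.cuda.current_stream()
    assert s.device == torch.device('cuda', torch.cuda.current_device())
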
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16200

Differential Revision: D13747455

Pulled By: mrshenli

fbshipit-source-id: 00c0d5f341c3ac7a757bdb4631a17e11fbc6d3ec
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 79a795b..bbd5536 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1422,6 +1422,67 @@
     def test_cuda_synchronize(self):
         torch.cuda.synchronize()
 
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_current_stream(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+
+        s0 = torch.cuda.current_stream()
+        s1 = torch.cuda.current_stream(device=1)
+        s2 = torch.cuda.current_stream(device=0)
+
+        self.assertEqual(d0, s0.device)
+        self.assertEqual(d1, s1.device)
+        self.assertEqual(d0, s2.device)
+        self.assertEqual(s0, s2)
+
+        with torch.cuda.device(d1):
+            s0 = torch.cuda.current_stream()
+            s1 = torch.cuda.current_stream(1)
+            s2 = torch.cuda.current_stream(d0)
+
+        self.assertEqual(d1, s0.device)
+        self.assertEqual(d1, s1.device)
+        self.assertEqual(d0, s2.device)
+        self.assertEqual(s0, s1)
+
+        with self.assertRaisesRegex(ValueError,
+                                    "Expected a cuda device, but got: cpu"):
+            torch.cuda.current_stream(torch.device('cpu'))
+
+    @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
+    @skipIfRocm
+    def test_default_stream(self):
+        d0 = torch.device('cuda:0')
+        d1 = torch.device('cuda:1')
+
+        with torch.cuda.device(d0):
+            s0 = torch.cuda.default_stream()
+
+        with torch.cuda.device(d1):
+            s1 = torch.cuda.default_stream()
+
+        s2 = torch.cuda.default_stream(device=0)
+        s3 = torch.cuda.default_stream(d1)
+
+        self.assertEqual(d0, s0.device)
+        self.assertEqual(d1, s1.device)
+        self.assertEqual(d0, s2.device)
+        self.assertEqual(d1, s3.device)
+        self.assertEqual(s0, s2)
+        self.assertEqual(s1, s3)
+
+        with torch.cuda.device(d0):
+            self.assertEqual(torch.cuda.current_stream(), s0)
+
+        with torch.cuda.device(d1):
+            self.assertEqual(torch.cuda.current_stream(), s1)
+
+        with self.assertRaisesRegex(ValueError,
+                                    "Expected a cuda device, but got: cpu"):
+            torch.cuda.default_stream(torch.device('cpu'))
+
     @skipIfRocm
     def test_streams(self):
         default_stream = torch.cuda.current_stream()
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 4408612..81ab835 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -66,11 +66,25 @@
   END_HANDLE_TH_ERRORS
 }
 
-
-PyObject * THCPModule_getCurrentStream_wrap(PyObject *self)
-{
+PyObject * THCPModule_getCurrentStream_wrap(
+    PyObject * /* unused */, PyObject *device_index) {
   HANDLE_TH_ERRORS
-  return PyLong_FromUnsignedLongLong(at::cuda::getCurrentCUDAStream().pack());
+  THPUtils_assert(
+    THPUtils_checkLong(device_index), "invalid argument to getCurrentStream");
+  int64_t device = THPUtils_unpackLong(device_index);
+  return PyLong_FromUnsignedLongLong(
+    at::cuda::getCurrentCUDAStream(device).pack());
+  END_HANDLE_TH_ERRORS
+}
+
+PyObject * THCPModule_getDefaultStream_wrap(
+    PyObject * /* unused */, PyObject *device_index) {
+  HANDLE_TH_ERRORS
+  THPUtils_assert(
+    THPUtils_checkLong(device_index), "invalid argument to getDefaultStream");
+  int64_t device = THPUtils_unpackLong(device_index);
+  return PyLong_FromUnsignedLongLong(
+    at::cuda::getDefaultCUDAStream(device).pack());
   END_HANDLE_TH_ERRORS
 }
 
@@ -412,7 +426,10 @@
   {"_cuda_setDevice",   (PyCFunction)THCPModule_setDevice_wrap,   METH_O,       nullptr},
   {"_cuda_getDevice",   (PyCFunction)THCPModule_getDevice_wrap,   METH_NOARGS,  nullptr},
   {"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
-  {"_cuda_getCurrentStream", (PyCFunction)THCPModule_getCurrentStream_wrap, METH_NOARGS, nullptr},
+  {"_cuda_getCurrentStream",
+    (PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
+  {"_cuda_getDefaultStream",
+    (PyCFunction)THCPModule_getDefaultStream_wrap, METH_O, nullptr},
   {"_cuda_getCurrentBlasHandle", (PyCFunction)THCPModule_getCurrentBlasHandle_wrap, METH_NOARGS, nullptr},
   {"_cuda_setStream",    (PyCFunction)THCPModule_setStream_wrap,  METH_O, nullptr},
   {"_cuda_isDriverSufficient", (PyCFunction)THCPModule_isDriverSufficient, METH_NOARGS, nullptr},
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index 8015b20..4f4519f 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -348,10 +348,32 @@
     return torch._C._cuda_synchronize()
 
 
-def current_stream():
-    r"""Returns a currently selected :class:`Stream`."""
+def current_stream(device=None):
+    r"""Returns the currently selected :class:`Stream` for a given device.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. If
+            :attr:`device` is ``None`` (default), returns the currently
+            selected :class:`Stream` for the current device, given by
+            :meth:`~torch.cuda.current_device`.
+    """
     _lazy_init()
-    return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream())
+    return torch.cuda.Stream(_cdata=torch._C._cuda_getCurrentStream(
+        _get_device_index(device, optional=True)))
+
+
+def default_stream(device=None):
+    r"""Returns the default :class:`Stream` for a given device.
+
+    Arguments:
+        device (torch.device or int, optional): selected device. If
+            :attr:`device` is ``None`` (default), returns the default
+            :class:`Stream` for the current device, given by
+            :meth:`~torch.cuda.current_device`.
+    """
+    _lazy_init()
+    return torch.cuda.Stream(_cdata=torch._C._cuda_getDefaultStream(
+        _get_device_index(device, optional=True)))
 
 
 def current_blas_handle():