add torch.cuda.synchronize(device=None) (#19573)

Summary:
fixes https://github.com/pytorch/pytorch/issues/19509
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19573

Differential Revision: D15045730

Pulled By: ezyang

fbshipit-source-id: 732721b4b360fc4348ca7c87d4cd1386e7651bdd
diff --git a/test/test_cuda.py b/test/test_cuda.py
index ad52d9b..88ae5b7 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -1497,6 +1497,21 @@
 
     def test_cuda_synchronize(self):
         torch.cuda.synchronize()
+        torch.cuda.synchronize('cuda')
+        torch.cuda.synchronize('cuda:0')
+        torch.cuda.synchronize(0)
+        torch.cuda.synchronize(torch.device('cuda:0'))
+
+        if TEST_MULTIGPU:
+            torch.cuda.synchronize('cuda:1')
+            torch.cuda.synchronize(1)
+            torch.cuda.synchronize(torch.device('cuda:1'))
+
+        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+            torch.cuda.synchronize(torch.device("cpu"))
+
+        with self.assertRaisesRegex(ValueError, "Expected a cuda device, but"):
+            torch.cuda.synchronize("cpu")
 
     @unittest.skipIf(not TEST_MULTIGPU, "detected only one GPU")
     def test_current_stream(self):
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index c9e4076..94fa9b6 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -271,7 +271,7 @@
     Arguments:
         device (torch.device or int, optional): device for which to return the
             name. This function is a no-op if this argument is a negative
-            integer. Uses the current device, given by :meth:`~torch.cuda.current_device`,
+            integer. It uses the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
     """
     return get_device_properties(device).name
@@ -283,8 +283,8 @@
     Arguments:
         device (torch.device or int, optional): device for which to return the
             device capability. This function is a no-op if this argument is
-            a negative integer. Uses the current device, given by
-            :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            a negative integer. It uses the current device, given by
+            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
             (default).
 
     Returns:
@@ -339,7 +339,7 @@
 
 
 def device_count():
-    """Returns the number of GPUs available."""
+    r"""Returns the number of GPUs available."""
     if is_available():
         return torch._C._cuda_getDeviceCount()
     else:
@@ -352,10 +352,17 @@
     return torch._C._cuda_getDevice()
 
 
-def synchronize():
-    r"""Waits for all kernels in all streams on current device to complete."""
+def synchronize(device=None):
+    r"""Waits for all kernels in all streams on a CUDA device to complete.
+
+    Arguments:
+        device (torch.device or int, optional): device for which to synchronize.
+            It uses the current device, given by :func:`~torch.cuda.current_device`,
+            if :attr:`device` is ``None`` (default).
+    """
     _lazy_init()
-    return torch._C._cuda_synchronize()
+    with torch.cuda.device(device):
+        return torch._C._cuda_synchronize()
 
 
 def ipc_collect():
@@ -377,7 +384,7 @@
     Arguments:
         device (torch.device or int, optional): selected device. Returns
             the currently selected :class:`Stream` for the current device, given
-            by :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
             (default).
     """
     _lazy_init()
@@ -391,7 +398,7 @@
     Arguments:
         device (torch.device or int, optional): selected device. Returns
             the default :class:`Stream` for the current device, given by
-            :meth:`~torch.cuda.current_device`, if :attr:`device` is ``None``
+            :func:`~torch.cuda.current_device`, if :attr:`device` is ``None``
             (default).
     """
     _lazy_init()
@@ -411,7 +418,7 @@
     `nvidia-smi`.
 
     .. note::
-        :meth:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
+        :func:`~torch.cuda.empty_cache` doesn't increase the amount of GPU
         memory available for PyTorch. See :ref:`cuda-memory-management` for
         more details about GPU memory management.
     """
@@ -425,7 +432,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
@@ -450,7 +457,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
@@ -469,7 +476,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
@@ -486,7 +493,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
@@ -509,7 +516,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
@@ -528,7 +535,7 @@
 
     Arguments:
         device (torch.device or int, optional): selected device. Returns
-            statistic for the current device, given by :meth:`~torch.cuda.current_device`,
+            statistic for the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
     .. note::
diff --git a/torch/cuda/__init__.pyi b/torch/cuda/__init__.pyi
index be85475..03da711 100644
--- a/torch/cuda/__init__.pyi
+++ b/torch/cuda/__init__.pyi
@@ -26,6 +26,7 @@
 def check_error(res: int) -> None: ...
 def device_count() -> int: ...
 def empty_cache() -> None: ...
+def synchronize(device: _device_t) -> None: ...
 def set_device(device: _device_t) -> None: ...
 def get_device_capability(device: Optional[_device_t]=...) -> Tuple[int, int]: ...
 def get_device_name(device: Optional[_device_t]=...) -> str: ...