Re-land "Fix error message for a wrong fork CUDA" (#23209)

Summary:
Re-land https://github.com/pytorch/pytorch/pull/23030
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23209

Differential Revision: D16440000

Pulled By: zhangguanheng66

fbshipit-source-id: e05683275522835a33d5a7e6d76b7e94774e4d98
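
For context, the failure mode this patch targets: once CUDA has been
initialized in a process, a fork()ed child cannot use it, and the error
should say so clearly. A minimal repro, mirroring the new test added below
(the tensor shape and process count are arbitrary):

    import torch
    from torch.multiprocessing import Process

    def run(rank):
        torch.cuda.set_device(rank)  # re-initializes CUDA in the forked child

    if __name__ == "__main__":
        x = torch.rand(20, 2).cuda()  # initializes CUDA in the parent
        procs = [Process(target=run, args=(rank,)) for rank in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()
        # each child fails with "Cannot re-initialize CUDA in forked subprocess. ..."
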
diff --git a/test/test_multiprocessing.py b/test/test_multiprocessing.py
index c28165c..c72962e 100644
--- a/test/test_multiprocessing.py
+++ b/test/test_multiprocessing.py
@@ -3,6 +3,7 @@
 import os
+import subprocess
 import sys
 import time
 import unittest
 from sys import platform
 
@@ -468,6 +469,27 @@
         p.join()
         self.assertIsInstance(outq.get(), RuntimeError)
 
+    @unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
+    def test_wrong_cuda_fork(self):
+        results = self.run_process_no_exception("""\
+import torch
+from torch.multiprocessing import Process
+def run(rank):
+    torch.cuda.set_device(rank)
+if __name__ == "__main__":
+    size = 2
+    processes = []
+    for rank in range(size):
+        # the .cuda() call below initializes CUDA in the parent; without it, the fork would be fine
+        x = torch.rand(20, 2).cuda()
+        p = Process(target=run, args=(rank,))
+        p.start()
+        processes.append(p)
+    for p in processes:
+        p.join()
+""")
+        self.assertRegex(results[1].decode('ascii'), "Cannot re-initialize CUDA in forked subprocess.")
+
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     @unittest.skipIf(not TEST_CUDA_IPC, 'CUDA IPC not available')
@@ -755,6 +777,15 @@
         param = Parameter(torch.arange(1., 26, device='cuda').view(5, 5))
         self._test_autograd_sharing(param, mp.get_context('spawn'), is_parameter=True)
 
+    @staticmethod
+    def run_process_no_exception(code):
+        popen = subprocess.Popen(
+            [sys.executable, '-c', code],
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE)
+        pipes = popen.communicate()
+        return pipes
+
     @unittest.skipIf(NO_MULTIPROCESSING_SPAWN, "Disabled for environments that \
                      don't support multiprocessing with spawn start method")
     def test_integer_parameter_serialization(self):
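
(The run_process_no_exception helper above just runs a code snippet in a
fresh interpreter and hands back the (stdout, stderr) bytes from
communicate(), so the assertion can inspect the child's stderr without the
test process itself forking after CUDA init. Inlined, it amounts to:)

    import subprocess
    import sys

    out, err = subprocess.Popen(
        [sys.executable, '-c', "print('ok')"],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
    assert out.strip() == b'ok' and err == b''
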
diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp
index 660e56a..6998601 100644
--- a/torch/csrc/cuda/Module.cpp
+++ b/torch/csrc/cuda/Module.cpp
@@ -17,6 +17,7 @@
 #include <torch/csrc/cuda/THCP.h>
 #include <torch/csrc/CudaIPCTypes.h>
 #include <torch/csrc/utils/pybind.h>
+#include <torch/csrc/utils/cuda_lazy_init.h>
 #include <torch/csrc/autograd/generated/VariableType.h>
 #include <torch/csrc/utils/python_strings.h>
 #include <torch/csrc/cuda/python_comm.h>
@@ -42,6 +43,7 @@
   THPUtils_assert(THPUtils_checkLong(arg), "invalid argument to setDevice");
   int64_t device = THPUtils_unpackLong(arg);
 
+  torch::utils::cuda_lazy_init();
   THCPModule_setDevice(device);
 
   Py_RETURN_NONE;
@@ -52,6 +54,7 @@
 {
   HANDLE_TH_ERRORS
   int device;
+  torch::utils::cuda_lazy_init();
   THCudaCheck(cudaGetDevice(&device));
   return PyLong_FromLong(device);
   END_HANDLE_TH_ERRORS
@@ -60,10 +63,19 @@
 PyObject * THCPModule_getDeviceCount_wrap(PyObject *self)
 {
   HANDLE_TH_ERRORS
+  // note: intentionally not calling cuda_lazy_init() here; querying the device count should not initialize CUDA
   return PyLong_FromLong(at::cuda::device_count());
   END_HANDLE_TH_ERRORS
 }
 
+PyObject * THCPModule_set_run_yet_variable_to_false_wrap(PyObject *self)
+{
+  HANDLE_TH_ERRORS
+  torch::utils::set_run_yet_variable_to_false();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject * THCPModule_getCurrentStream_wrap(
     PyObject * /* unused */, PyObject *device_index) {
   HANDLE_TH_ERRORS
@@ -387,6 +399,8 @@
   {"_cuda_setDevice",   (PyCFunction)THCPModule_setDevice_wrap,   METH_O,       nullptr},
   {"_cuda_getDevice",   (PyCFunction)THCPModule_getDevice_wrap,   METH_NOARGS,  nullptr},
   {"_cuda_getDeviceCount", (PyCFunction)THCPModule_getDeviceCount_wrap, METH_NOARGS, nullptr},
+  {"_cuda_set_run_yet_variable_to_false",
+    (PyCFunction)THCPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
   {"_cuda_getCurrentStream",
     (PyCFunction)THCPModule_getCurrentStream_wrap, METH_O, nullptr},
   {"_cuda_getDefaultStream",
diff --git a/torch/csrc/utils/cuda_lazy_init.cpp b/torch/csrc/utils/cuda_lazy_init.cpp
index 787b22f..cee8771 100644
--- a/torch/csrc/utils/cuda_lazy_init.cpp
+++ b/torch/csrc/utils/cuda_lazy_init.cpp
@@ -5,9 +5,10 @@
 
 #include <torch/csrc/Exceptions.h>
 #include <torch/csrc/utils/object_ptr.h>
-
 namespace torch {
 namespace utils {
+
+static bool run_yet = false;
 
 void cuda_lazy_init() {
   AutoGIL g;
@@ -15,7 +16,6 @@
   // has a buggy implementation that deadlocks if an instance throws an
   // exception.  In any case, call_once isn't necessary, because we
   // have taken a lock.
-  static bool run_yet = false;
   if (!run_yet) {
     auto module = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
     if (!module) throw python_error();
@@ -25,5 +25,9 @@
   }
 }
 
+void set_run_yet_variable_to_false() {
+  run_yet = false;
+}
+
 }
 }
diff --git a/torch/csrc/utils/cuda_lazy_init.h b/torch/csrc/utils/cuda_lazy_init.h
index f8522c1..f0c5336 100644
--- a/torch/csrc/utils/cuda_lazy_init.h
+++ b/torch/csrc/utils/cuda_lazy_init.h
@@ -19,6 +19,7 @@
 // build, which is not good UX.
 //
 void cuda_lazy_init();
+void set_run_yet_variable_to_false();
 
 }
 }
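
(Putting the pieces together: run_yet now lives at namespace scope so it can
be reset from outside cuda_lazy_init(). The post-fork flow, sketched as a
comment trace of this patch's code paths:)

    # parent: x.cuda() runs cuda_lazy_init(); run_yet becomes true
    # fork:   the child inherits run_yet == true but a broken CUDA context
    # child:  _after_fork (below) sets _in_bad_fork and calls
    #         torch._C._cuda_set_run_yet_variable_to_false()
    # child:  torch.cuda.set_device(rank) re-enters cuda_lazy_init(), which
    #         imports torch.cuda and calls _lazy_init; _lazy_init sees
    #         _in_bad_fork and raises the informative RuntimeError
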
diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py
index cc89b42..411cfb7 100644
--- a/torch/cuda/__init__.py
+++ b/torch/cuda/__init__.py
@@ -199,7 +199,7 @@
         _initialized = False
         _in_bad_fork = True
         _CudaBase.__new__ = _lazy_new
-
+        torch._C._cuda_set_run_yet_variable_to_false()
 
 _register_after_fork(_after_fork, _after_fork)
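
(With the flag reset in place, the child fails fast with an actionable
message instead of a confusing initialization error. For reference, the
supported pattern is the 'spawn' start method, as the error text itself
suggests; a sketch:)

    import torch
    import torch.multiprocessing as mp

    def run(rank):
        torch.cuda.set_device(rank)

    if __name__ == "__main__":
        x = torch.rand(20, 2).cuda()
        ctx = mp.get_context('spawn')  # spawn re-imports torch in each child,
                                       # so CUDA initializes cleanly there
        procs = [ctx.Process(target=run, args=(rank,)) for rank in range(2)]
        for p in procs:
            p.start()
        for p in procs:
            p.join()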