Move OMP/MKL thread initialization into ATen/Parallel (#19011)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19011
ghimport-source-id: 432e31eccfd0e59fa21a790f861e6b2ff4fdbac6
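
This moves the OMP/MKL thread initialization logic out of TH
(THSetNumThreads, THGetNumThreads, THGetNumCores, THInferNumThreads) and
the old ATen/CPUGeneral into ATen/Parallel:

- at::init_num_threads() performs the per-thread OMP/MKL setup that
  THInferNumThreads() used to do; it is now called from initModule() and
  from the autograd engine's thread_init().
- at::set_num_threads() updates OpenMP and MKL together (and disables
  mkl_set_dynamic, see https://github.com/pytorch/pytorch/issues/13757).
- at::get_num_threads() replaces at::get_max_threads() at all call sites.
- The OMP_NUM_THREADS/MKL_NUM_THREADS parsing in at::init() is removed;
  those environment variables are honored by OpenMP/MKL themselves.
- tbb_init_test.cpp is renamed to thread_init_test.cpp and updated for
  the new API.

A minimal usage sketch of the new contract from C++ (the worker function
and tensor shapes below are illustrative, not part of this change):

  #include <ATen/ATen.h>
  #include <ATen/Parallel.h>
  #include <thread>

  void worker() {
    // Each new thread initializes its thread-local OMP/MKL state before
    // running ATen code, mirroring Engine::thread_init.
    at::init_num_threads();
    auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
    auto s = t.sum();  // parallelized using at::get_num_threads() threads
  }

  int main() {
    at::init_num_threads();
    at::set_num_threads(4);  // applies to both OpenMP and MKL
    std::thread th(worker);  // worker observes the requested thread count
    th.join();
    return 0;
  }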

Differential Revision: D14846034

Pulled By: ilia-cher

fbshipit-source-id: d9d03c761d34bac80e09ce776e41c20fd3b04389
diff --git a/aten/src/ATen/ATen.h b/aten/src/ATen/ATen.h
index 16adfc3..3d55be6 100644
--- a/aten/src/ATen/ATen.h
+++ b/aten/src/ATen/ATen.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <c10/core/Allocator.h>
-#include <ATen/CPUGeneral.h>
+#include <ATen/core/ATenGeneral.h>
 #include <ATen/Context.h>
 #include <ATen/Device.h>
 #include <ATen/DeviceGuard.h>
diff --git a/aten/src/ATen/CPUGeneral.cpp b/aten/src/ATen/CPUGeneral.cpp
deleted file mode 100644
index 910e3ae..0000000
--- a/aten/src/ATen/CPUGeneral.cpp
+++ /dev/null
@@ -1,16 +0,0 @@
-#include <ATen/CPUGeneral.h>
-#include <atomic>
-#include <memory>
-#include <thread>
-
-namespace at {
-// Lock free atomic type
-std::atomic<int> num_threads(-1);
-
-void set_num_threads(int num_threads_) {
-  if (num_threads_ >= 0)
-    num_threads.store(num_threads_);
-}
-
-int get_num_threads() { return num_threads.load(); }
-}
diff --git a/aten/src/ATen/CPUGeneral.h b/aten/src/ATen/CPUGeneral.h
deleted file mode 100644
index 246af37..0000000
--- a/aten/src/ATen/CPUGeneral.h
+++ /dev/null
@@ -1,12 +0,0 @@
-#pragma once
-
-// Using CAFFE2_API is crucial as otherwise you'll see
-// linking errors using MSVC
-// See https://msdn.microsoft.com/en-us/library/a90k134d.aspx
-// This header adds this if using CAFFE2_API
-#include <ATen/core/ATenGeneral.h>
-
-namespace at {
-CAFFE2_API void set_num_threads(int);
-CAFFE2_API int get_num_threads();
-}
diff --git a/aten/src/ATen/Context.h b/aten/src/ATen/Context.h
index 9033494..925b327 100644
--- a/aten/src/ATen/Context.h
+++ b/aten/src/ATen/Context.h
@@ -1,6 +1,6 @@
 #pragma once
 
-#include <ATen/CPUGeneral.h>
+#include <ATen/core/ATenGeneral.h>
 #include <ATen/Type.h>
 #include <ATen/TypeExtendedInterface.h>
 #include <ATen/Utils.h>
@@ -164,12 +164,6 @@
 
 static inline void init() {
   globalContext();
-  if (const char *env_p = std::getenv("OMP_NUM_THREADS")) {
-    at::set_num_threads(std::stoi(env_p));
-  }
-  if (const char *env_p = std::getenv("MKL_NUM_THREADS")) {
-    at::set_num_threads(std::stoi(env_p));
-  }
 }
 
 static inline TypeExtendedInterface& getNonVariableType(Backend p, ScalarType s) {
diff --git a/aten/src/ATen/Parallel.cpp b/aten/src/ATen/Parallel.cpp
new file mode 100644
index 0000000..3345e09
--- /dev/null
+++ b/aten/src/ATen/Parallel.cpp
@@ -0,0 +1,63 @@
+#include <ATen/Parallel.h>
+
+#include <atomic>
+
+#ifdef TH_BLAS_MKL
+#include <mkl.h>
+#endif
+
+namespace at {
+
+namespace {
+// Number of threads set by the user
+std::atomic<int> num_threads(-1);
+} // namespace
+
+void init_num_threads() {
+  auto nthreads = num_threads.load();
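+  // A positive value means the user has explicitly requested a thread
+  // count; re-apply it so this thread's OpenMP/MKL state matches.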
+  if (nthreads > 0) {
+    set_num_threads(nthreads);
+  } else {
+#if defined(_OPENMP) && defined(TH_BLAS_MKL)
+    // If we are using MKL and OpenMP, make sure the number of threads
+    // matches. Otherwise, MKL and our OpenMP-enabled functions will keep
+    // changing the size of the OpenMP thread pool, resulting in worse
+    // performance (and memory leaks in GCC 5.4).
+    omp_set_num_threads(mkl_get_max_threads());
+#endif
+  }
+}
+
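+// nthreads == 0 is ignored: the current OpenMP/MKL setting and the
+// stored user request are left unchanged.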
+void set_num_threads(size_t nthreads) {
+  if (nthreads == 0) {
+    return;
+  }
+  num_threads.store(nthreads);
+#ifdef _OPENMP
+  omp_set_num_threads(nthreads);
+#endif
+#ifdef TH_BLAS_MKL
+  mkl_set_num_threads(nthreads);
+
+  // because PyTorch uses OpenMP outside of MKL invocations
+  // as well, we want this flag to be false, so that
+  // threads aren't destroyed and recreated across every
+  // MKL / non-MKL boundary of OpenMP usage
+  // See https://github.com/pytorch/pytorch/issues/13757
+  mkl_set_dynamic(false);
+#endif
+}
+
+// Explicitly call omp_get_max_threads() here, as the size of the parallel
+// region might differ in a new thread; use init_num_threads() during
+// thread initialization to ensure a consistent parallel region size
+// across threads.
+size_t get_num_threads() {
+#ifdef _OPENMP
+  return omp_get_max_threads();
+#else
+  return 1;
+#endif
+}
+
+} // namespace at
diff --git a/aten/src/ATen/Parallel.h b/aten/src/ATen/Parallel.h
index fe073f1..0d00a9e 100644
--- a/aten/src/ATen/Parallel.h
+++ b/aten/src/ATen/Parallel.h
@@ -22,14 +22,17 @@
   return (x + y - 1) / y;
 }
 
-inline int get_max_threads() {
-#ifdef _OPENMP
-  return omp_get_max_threads();
-#else
-  return 1;
-#endif
-}
+// Called during new thread initialization
+C10_API void init_num_threads();
 
+// Sets the number of threads to be used in the parallel region
+C10_API void set_num_threads(size_t);
+
+// Returns the number of threads used in the parallel region
+C10_API size_t get_num_threads();
+
+// Returns the current thread number (starting from 0)
+// in the current parallel region, or 0 in the sequential region
 inline int get_thread_num() {
 #ifdef _OPENMP
   return omp_get_thread_num();
diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp
index 3da24a1..d9ba8fa 100644
--- a/aten/src/ATen/native/BatchLinearAlgebra.cpp
+++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp
@@ -73,7 +73,7 @@
 template<class scalar_t>
 void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
   AT_ERROR("triangular_solve only takes float or double Tensors");
-} 
+}
 
 #ifdef USE_LAPACK
 template<> void lapackSolve<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
diff --git a/aten/src/ATen/native/TensorIterator.cpp b/aten/src/ATen/native/TensorIterator.cpp
index e86444a..ae823a3 100644
--- a/aten/src/ATen/native/TensorIterator.cpp
+++ b/aten/src/ATen/native/TensorIterator.cpp
@@ -357,7 +357,7 @@
   int64_t numel = this->numel();
   if (numel == 0) {
     return;
-  } else if (numel < internal::GRAIN_SIZE || at::get_max_threads() == 1) {
+  } else if (numel < internal::GRAIN_SIZE || at::get_num_threads() == 1) {
     return serial_for_each(loop, {0, numel});
   } else {
     at::parallel_for(0, numel, internal::GRAIN_SIZE, [&](int64_t begin, int64_t end) {
diff --git a/aten/src/ATen/native/TensorIteratorReduce.cpp b/aten/src/ATen/native/TensorIteratorReduce.cpp
index f6d2028..228bb13 100644
--- a/aten/src/ATen/native/TensorIteratorReduce.cpp
+++ b/aten/src/ATen/native/TensorIteratorReduce.cpp
@@ -16,7 +16,8 @@
 void TensorIterator::parallel_reduce(const loop2d_t& loop) {
   AT_CHECK(ntensors() == 2, "parallel_reduce only supports one input and one output");
   int64_t numel = this->numel();
-  if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
+  if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+      at::in_parallel_region()) {
     serial_for_each(loop, {0, numel});
   } else if (use_two_pass_reduction(*this)) {
     two_pass_reduction(*this, loop);
@@ -30,7 +31,7 @@
 }
 
 static void two_pass_reduction(TensorIterator& iter, const loop2d_t& loop) {
-  int max_threads = at::get_max_threads();
+  int max_threads = at::get_num_threads();
 
   auto& dst = iter.tensor(0);
   auto buffer_shape = DimVector(dst.sizes());
@@ -65,7 +66,7 @@
 /// Chooses a dimension over which to parallelize. Prefers the outer-most
 /// dimension that's larger than the number of available threads.
 static int find_split_dim(TensorIterator& iter) {
-  int num_threads = at::get_max_threads();
+  int num_threads = at::get_num_threads();
   auto shape = iter.shape();
 
   // start with the outer-most dimension
@@ -125,7 +126,8 @@
   if (tensor(0).numel() == 1) {
     loop(*this);
   }
-  else if (numel() < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region() || !parallelize) {
+  else if (numel() < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+      at::in_parallel_region() || !parallelize) {
     auto reduce_dims = num_reduce_dims();
 
     auto non_reduced_shape = shape.slice(reduce_dims, shape.size() - reduce_dims);
@@ -154,7 +156,7 @@
 
       sub_iter.narrow(dim, begin, end - begin);
       // On some broken setups, `#ifdef _OPENMP` is true,
-      // and `get_max_threads` returns > 1, but
+      // and `get_num_threads` returns > 1, but
       // `#pragma omp parallel` is ignored.
       // There is no API to check for this, so we need to explicitly
       // stop trying to parallelize if we've already gotten here.
diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h
index 7500781..dbc469a 100644
--- a/aten/src/ATen/native/cpu/Reduce.h
+++ b/aten/src/ATen/native/cpu/Reduce.h
@@ -97,10 +97,11 @@
     };
     acc_t total_acc = init;
     auto numel = sub_iter.numel();
-    if (numel < at::internal::GRAIN_SIZE || at::get_max_threads() == 1 || at::in_parallel_region()) {
+    if (numel < at::internal::GRAIN_SIZE || at::get_num_threads() == 1 ||
+        at::in_parallel_region()) {
       total_acc = reduction_body(total_acc, 0, numel);
     } else {
-      int max_threads = at::get_max_threads();
+      int max_threads = at::get_num_threads();
       AT_ASSERT(max_threads > 0);
       static_assert(
         !std::is_same<acc_t, bool>::value,
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index bcee7ef..20b643c 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -19,7 +19,7 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/test_parallel.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/thread_init_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/quantized_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
diff --git a/aten/src/ATen/test/tbb_init_test.cpp b/aten/src/ATen/test/thread_init_test.cpp
similarity index 73%
rename from aten/src/ATen/test/tbb_init_test.cpp
rename to aten/src/ATen/test/thread_init_test.cpp
index 5e17fae..adbb324 100644
--- a/aten/src/ATen/test/tbb_init_test.cpp
+++ b/aten/src/ATen/test/thread_init_test.cpp
@@ -9,24 +9,24 @@
 // will throw an exception when multiple threads call
 // their first parallel construct.
 void test(int given_num_threads) {
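+  // Initialize this thread's local OMP/MKL state before any ATen calls.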
+  at::init_num_threads();
   auto t = at::ones({1000 * 1000}, at::CPU(at::kFloat));
-  if (given_num_threads >= 0) {
-    ASSERT(at::get_num_threads() == given_num_threads);
-  } else {
-    ASSERT(at::get_num_threads() == -1);
-  }
+  ASSERT(given_num_threads >= 0);
+  ASSERT(at::get_num_threads() == given_num_threads);
   auto t_sum = t.sum();
-  for (int i = 0; i < 1000; i ++) {
+  for (int i = 0; i < 1000; ++i) {
     t_sum = t_sum + t.sum();
   }
 }
 
 int main() {
+  at::init_num_threads();
   at::manual_seed(123);
 
-  test(-1);
-  std::thread t1(test, -1);
+  test(at::get_num_threads());
+  std::thread t1(test, at::get_num_threads());
   t1.join();
+
   at::set_num_threads(4);
   std::thread t2(test, 4);
   std::thread t3(test, 4);
@@ -34,6 +34,7 @@
   t4.join();
   t3.join();
   t2.join();
+
   at::set_num_threads(5);
   test(5);
 
diff --git a/aten/src/TH/THGeneral.cpp b/aten/src/TH/THGeneral.cpp
index 8df8ff5..f8f8a3b 100644
--- a/aten/src/TH/THGeneral.cpp
+++ b/aten/src/TH/THGeneral.cpp
@@ -4,10 +4,6 @@
 #include <c10/core/CPUAllocator.h>
 #endif
 
-#ifdef _OPENMP
-#include <omp.h>
-#endif
-
 #ifndef TH_HAVE_THREAD
 #define __thread
 #elif _MSC_VER
@@ -24,10 +20,6 @@
 #include <malloc/malloc.h>
 #endif
 
-#ifdef TH_BLAS_MKL
-#include <mkl.h>
-#endif
-
 /* Torch Error Handling */
 static void defaultErrorHandlerFunction(const char *msg, void *data)
 {
@@ -224,53 +216,6 @@
   return expm1(x);
 }
 
-void THSetNumThreads(int num_threads)
-{
-#ifdef _OPENMP
-  omp_set_num_threads(num_threads);
-#endif
-#ifdef TH_BLAS_MKL
-  mkl_set_num_threads(num_threads);
-
-  // because PyTorch uses OpenMP outside of MKL invocations
-  // as well, we want this flag to be false, so that
-  // threads aren't destroyed and recreated across every
-  // MKL / non-MKL boundary of OpenMP usage
-  // See https://github.com/pytorch/pytorch/issues/13757
-  mkl_set_dynamic(false);
-#endif
-
-}
-
-int THGetNumThreads(void)
-{
-#ifdef _OPENMP
-  return omp_get_max_threads();
-#else
-  return 1;
-#endif
-}
-
-int THGetNumCores(void)
-{
-#ifdef _OPENMP
-  return omp_get_num_procs();
-#else
-  return 1;
-#endif
-}
-
-TH_API void THInferNumThreads(void)
-{
-#if defined(_OPENMP) && defined(TH_BLAS_MKL)
-  // If we are using MKL an OpenMP make sure the number of threads match.
-  // Otherwise, MKL and our OpenMP-enabled functions will keep changing the
-  // size of the OpenMP thread pool, resulting in worse performance (and memory
-  // leaks in GCC 5.4)
-  omp_set_num_threads(mkl_get_max_threads());
-#endif
-}
-
 THDescBuff _THSizeDesc(const int64_t *size, const int64_t ndim) {
   const int L = TH_DESC_BUFF_LEN;
   THDescBuff buf;
diff --git a/aten/src/TH/THGeneral.h.in b/aten/src/TH/THGeneral.h.in
index 895ebf4..f5ea797 100644
--- a/aten/src/TH/THGeneral.h.in
+++ b/aten/src/TH/THGeneral.h.in
@@ -100,10 +100,6 @@
 TH_API void THSetGCHandler( void (*torchGCHandlerFunction)(void *data), void *data );
 // this hook should only be called by custom allocator functions
 TH_API void THHeapUpdate(ptrdiff_t size);
-TH_API void THSetNumThreads(int num_threads);
-TH_API int THGetNumThreads(void);
-TH_API int THGetNumCores(void);
-TH_API void THInferNumThreads(void);
 
 #define THError(...) _THError(__FILE__, __LINE__, __VA_ARGS__)
 
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 522fcca..98512af 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -2094,7 +2094,7 @@
            r"""
 get_num_threads() -> int
 
-Gets the number of OpenMP threads used for parallelizing CPU operations
+Gets the number of threads used for parallelizing CPU operations
 """)
 
 add_docstr(torch.gt,
@@ -4282,7 +4282,10 @@
            r"""
 set_num_threads(int)
 
-Sets the number of OpenMP threads used for parallelizing CPU operations
+Sets the number of threads used for parallelizing CPU operations.
+
+.. warning::
+    To ensure that the correct number of threads is used, set_num_threads
+    must be called before running eager, JIT, or autograd code.
 """)
 
 add_docstr(torch.sigmoid,
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index df44a3e..9d47ae7 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -13,6 +13,7 @@
 #include <ATen/ExpandUtils.h>
 #include <ATen/dlpack.h>
 #include <ATen/DLConvertor.h>
+#include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -147,14 +148,13 @@
 
 static PyObject * THPModule_getNumThreads(PyObject *module)
 {
-  return PyLong_FromLong(THGetNumThreads());
+  return PyLong_FromLong(at::get_num_threads());
 }
 
 static PyObject * THPModule_setNumThreads(PyObject *module, PyObject *arg)
 {
   THPUtils_assert(THPUtils_checkLong(arg), "set_num_threads expects an int, "
           "but got %s", THPUtils_typename(arg));
-  THSetNumThreads((int)THPUtils_unpackLong(arg));
   at::set_num_threads((int)THPUtils_unpackLong(arg));
   Py_RETURN_NONE;
 }
@@ -541,7 +541,7 @@
 #endif
 PyObject* initModule() {
   HANDLE_TH_ERRORS
-  THInferNumThreads();
+  at::init_num_threads();
 
 #define ASSERT_TRUE(cmd) if (!(cmd)) return nullptr
 
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 8882edc..b4716dd 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -9,6 +9,7 @@
 
 #include <ATen/DeviceGuard.h>
 #include <ATen/ExpandUtils.h>
+#include <ATen/Parallel.h>
 #include <c10/util/Exception.h>
 
 #include <atomic>
@@ -204,7 +205,7 @@
 Engine::~Engine() = default;
 
 auto Engine::thread_init(int device) -> void {
-  THInferNumThreads();
+  at::init_num_threads();
   // Note [Allocating GPUs to autograd threads]
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   // What's our strategy here?  Originally, the autograd engine was written