Make torch.set_num_threads also set MKL threads (take 2) (#5002)
* torch.set_num_threads sets the MKL thread count too
* Fix to use the C prototypes instead of the Fortran ones
diff --git a/aten/src/TH/THGeneral.c b/aten/src/TH/THGeneral.c
index 70da018..ff71507 100644
--- a/aten/src/TH/THGeneral.c
+++ b/aten/src/TH/THGeneral.c
@@ -21,6 +21,13 @@
#include <malloc/malloc.h>
#endif
+#ifdef TH_BLAS_MKL
+// This is the C prototype; mkl_set_num_threads is the Fortran prototype.
+extern void MKL_Set_Num_Threads(int);
+// This is the C prototype; mkl_get_max_threads is the Fortran prototype.
+extern int MKL_Get_Max_Threads(void);
+#endif
+
/* Torch Error Handling */
static void defaultErrorHandlerFunction(const char *msg, void *data)
{
@@ -302,6 +309,10 @@
#ifdef _OPENMP
omp_set_num_threads(num_threads);
#endif
+#ifdef TH_BLAS_MKL
+ MKL_Set_Num_Threads(num_threads);
+#endif
+
}
int THGetNumThreads(void)
@@ -322,10 +333,6 @@
#endif
}
-#ifdef TH_BLAS_MKL
-extern int mkl_get_max_threads(void);
-#endif
-
TH_API void THInferNumThreads(void)
{
#if defined(_OPENMP) && defined(TH_BLAS_MKL)
@@ -333,7 +340,7 @@
// Otherwise, MKL and our OpenMP-enabled functions will keep changing the
// size of the OpenMP thread pool, resulting in worse performance (and memory
// leaks in GCC 5.4)
- omp_set_num_threads(mkl_get_max_threads());
+ omp_set_num_threads(MKL_Get_Max_Threads());
#endif
}