| #include "THGeneral.h" |
| #include "THAtomic.h" |
| |
| #ifdef _OPENMP |
| #include <omp.h> |
| #endif |
| |
| #ifndef TH_HAVE_THREAD |
| #define __thread |
| #elif _MSC_VER |
| #define __thread __declspec( thread ) |
| #endif |
| |
| #if (defined(__unix) || defined(_WIN32)) |
| #if defined(__FreeBSD__) |
| #include <malloc_np.h> |
| #else |
| #include <malloc.h> |
| #endif |
| #elif defined(__APPLE__) |
| #include <malloc/malloc.h> |
| #endif |
| |
| #ifdef TH_BLAS_MKL |
| // this is the C prototype, while mkl_set_num_threads is the fortran prototype |
| extern void MKL_Set_Num_Threads(int); |
| // this is the C prototype, while mkl_get_max_threads is the fortran prototype |
| extern int MKL_Get_Max_Threads(void); |
| #endif |
| |
| /* Torch Error Handling */ |
/* Built-in error handler used until THSetDefaultErrorHandler installs another.
 * Writes "$ Error: <msg>" to stdout and terminates the process with exit(-1).
 * The opaque data pointer is ignored. */
static void defaultErrorHandlerFunction(const char *msg, void *data)
{
  (void)data; /* unused by the built-in handler */
  fprintf(stdout, "$ Error: %s\n", msg);
  exit(-1);
}
| |
| static THErrorHandlerFunction defaultErrorHandler = defaultErrorHandlerFunction; |
| static void *defaultErrorHandlerData; |
| static __thread THErrorHandlerFunction threadErrorHandler = NULL; |
| static __thread void *threadErrorHandlerData; |
| |
| void _THError(const char *file, const int line, const char *fmt, ...) |
| { |
| char msg[2048]; |
| va_list args; |
| |
| /* vasprintf not standard */ |
| /* vsnprintf: how to handle if does not exists? */ |
| va_start(args, fmt); |
| int n = vsnprintf(msg, 2048, fmt, args); |
| va_end(args); |
| |
| if(n < 2048) { |
| snprintf(msg + n, 2048 - n, " at %s:%d", file, line); |
| } |
| |
| if (threadErrorHandler) |
| (*threadErrorHandler)(msg, threadErrorHandlerData); |
| else |
| (*defaultErrorHandler)(msg, defaultErrorHandlerData); |
| TH_UNREACHABLE; |
| } |
| |
/* Reports a failed assertion through _THError as
 * "Assertion `<exp>' failed. <formatted message> at <file>:<line>".
 *
 * exp      : text of the failed expression.
 * fmt, ... : printf-style extra detail (truncated to the local buffer). */
void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...) {
  char msg[1024];
  va_list args;
  va_start(args, fmt);
  int n = vsnprintf(msg, sizeof(msg), fmt, args);
  va_end(args);

  /* On a vsnprintf failure the buffer contents are unspecified; pass an
   * empty detail string rather than reading an indeterminate buffer. */
  if (n < 0) {
    msg[0] = '\0';
  }

  _THError(file, line, "Assertion `%s' failed. %s", exp, msg);
}
| |
| void THSetErrorHandler(THErrorHandlerFunction new_handler, void *data) |
| { |
| threadErrorHandler = new_handler; |
| threadErrorHandlerData = data; |
| } |
| |
| void THSetDefaultErrorHandler(THErrorHandlerFunction new_handler, void *data) |
| { |
| if (new_handler) |
| defaultErrorHandler = new_handler; |
| else |
| defaultErrorHandler = defaultErrorHandlerFunction; |
| defaultErrorHandlerData = data; |
| } |
| |
| /* Torch Arg Checking Handling */ |
/* Built-in argument-check handler: prints "$ Invalid argument <n>" (followed
 * by ": <msg>" when a message is supplied) to stdout, then terminates the
 * process with exit(-1). The opaque data pointer is ignored. */
static void defaultArgErrorHandlerFunction(int argNumber, const char *msg, void *data)
{
  (void)data; /* unused by the built-in handler */
  if (msg == NULL) {
    fprintf(stdout, "$ Invalid argument %d\n", argNumber);
  } else {
    fprintf(stdout, "$ Invalid argument %d: %s\n", argNumber, msg);
  }
  exit(-1);
}
| |
| static THArgErrorHandlerFunction defaultArgErrorHandler = defaultArgErrorHandlerFunction; |
| static void *defaultArgErrorHandlerData; |
| static __thread THArgErrorHandlerFunction threadArgErrorHandler = NULL; |
| static __thread void *threadArgErrorHandlerData; |
| |
| void _THArgCheck(const char *file, int line, int condition, int argNumber, const char *fmt, ...) |
| { |
| if(!condition) { |
| char msg[2048]; |
| va_list args; |
| |
| /* vasprintf not standard */ |
| /* vsnprintf: how to handle if does not exists? */ |
| va_start(args, fmt); |
| int n = vsnprintf(msg, 2048, fmt, args); |
| va_end(args); |
| |
| if(n < 2048) { |
| snprintf(msg + n, 2048 - n, " at %s:%d", file, line); |
| } |
| |
| if (threadArgErrorHandler) |
| (*threadArgErrorHandler)(argNumber, msg, threadArgErrorHandlerData); |
| else |
| (*defaultArgErrorHandler)(argNumber, msg, defaultArgErrorHandlerData); |
| TH_UNREACHABLE; |
| } |
| } |
| |
| void THSetArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data) |
| { |
| threadArgErrorHandler = new_handler; |
| threadArgErrorHandlerData = data; |
| } |
| |
| void THSetDefaultArgErrorHandler(THArgErrorHandlerFunction new_handler, void *data) |
| { |
| if (new_handler) |
| defaultArgErrorHandler = new_handler; |
| else |
| defaultArgErrorHandler = defaultArgErrorHandlerFunction; |
| defaultArgErrorHandlerData = data; |
| } |
| |
/* GC hook installed by THSetGCHandler (per thread when TH_HAVE_THREAD is
 * defined); NULL means no hook is installed. */
static __thread void (*torchGCFunction)(void *data) = NULL;
static __thread void *torchGCData = NULL;

/* Heap accounting. heapSize is shared and updated atomically in
 * applyHeapDelta; heapDelta is the calling thread's not-yet-applied change. */
static ptrdiff_t heapSize = 0;
static __thread ptrdiff_t heapDelta = 0;

/* Intended bounds (+/- 1,000,000 bytes) on heapDelta before heapSize is
 * updated. NOTE(review): neither constant is referenced in this file. */
static const ptrdiff_t heapMaxDelta = (ptrdiff_t)1000000;
static const ptrdiff_t heapMinDelta = (ptrdiff_t)-1000000;

/* Per-thread soft limit (300,000,000 bytes initially) that maybeTriggerGC
 * compares against and raises when memory stays high after a GC. */
static __thread ptrdiff_t heapSoftmax = (ptrdiff_t)300000000;
static const double heapSoftmaxGrowthThresh = 0.8; /* grow if post-GC size > 80% of softmax */
static const double heapSoftmaxGrowthFactor = 1.4; /* grow softmax by 40% */
| |
/* Optional hook for integrating with a garbage-collected frontend.
 *
 * If torch is running with a garbage-collected frontend (e.g. Lua),
 * the GC isn't aware of TH-allocated memory, so it may not know when it
 * needs to run. These hooks trigger the GC to run in two cases:
 *
 * (1) When a memory allocation (malloc, realloc, ...) fails
 * (2) When the total TH-allocated memory hits a dynamically-adjusted
 *     soft maximum.
 */
| void THSetGCHandler( void (*torchGCFunction_)(void *data), void *data ) |
| { |
| torchGCFunction = torchGCFunction_; |
| torchGCData = data; |
| } |
| |
/* Returns the usable size in bytes of the heap block at ptr, as reported by
 * the platform allocator, or 0 when no query is available (non-Apple Unix
 * without HAVE_MALLOC_USABLE_SIZE, other platforms) or, on Windows, when ptr
 * is NULL. Narrowing to ptrdiff_t relies on allocations never exceeding
 * PTRDIFF_MAX. */
static ptrdiff_t getAllocSize(void *ptr) {
  ptrdiff_t usable = 0;
#if defined(__unix) && defined(HAVE_MALLOC_USABLE_SIZE)
  usable = (ptrdiff_t)malloc_usable_size(ptr);
#elif defined(__APPLE__)
  usable = (ptrdiff_t)malloc_size(ptr);
#elif defined(_WIN32)
  usable = (ptr != NULL) ? (ptrdiff_t)_msize(ptr) : 0;
#else
  (void)ptr; /* no allocator query on this platform */
#endif
  return usable;
}
| |
| static ptrdiff_t applyHeapDelta() { |
| ptrdiff_t oldHeapSize = THAtomicAddPtrdiff(&heapSize, heapDelta); |
| #ifdef DEBUG |
| if (heapDelta > 0 && oldHeapSize > PTRDIFF_MAX - heapDelta) |
| THError("applyHeapDelta: heapSize(%td) + increased(%td) > PTRDIFF_MAX, heapSize overflow!", oldHeapSize, heapDelta); |
| if (heapDelta < 0 && oldHeapSize < PTRDIFF_MIN - heapDelta) |
| THError("applyHeapDelta: heapSize(%td) + decreased(%td) < PTRDIFF_MIN, heapSize underflow!", oldHeapSize, heapDelta); |
| #endif |
| ptrdiff_t newHeapSize = oldHeapSize + heapDelta; |
| heapDelta = 0; |
| return newHeapSize; |
| } |
| |
| /* (1) if the torch-allocated heap size exceeds the soft max, run GC |
| * (2) if post-GC heap size exceeds 80% of the soft max, increase the |
| * soft max by 40% |
| */ |
| static void maybeTriggerGC(ptrdiff_t curHeapSize) { |
| if (torchGCFunction && curHeapSize > heapSoftmax) { |
| torchGCFunction(torchGCData); |
| |
| // ensure heapSize is accurate before updating heapSoftmax |
| ptrdiff_t newHeapSize = applyHeapDelta(); |
| |
| if (newHeapSize > heapSoftmax * heapSoftmaxGrowthThresh) { |
| heapSoftmax = (ptrdiff_t)(heapSoftmax * heapSoftmaxGrowthFactor); |
| } |
| } |
| } |
| |
/* Raw allocation used by THAlloc. On POSIX/Apple builds (unless
 * DISABLE_POSIX_MEMALIGN is defined), requests larger than kAlignedThreshold
 * bytes come from posix_memalign with kAlignment-byte alignment. All other
 * requests use plain malloc. Returns NULL on failure.
 *
 * Windows intentionally does not use _aligned_malloc: memory from it must be
 * released with _aligned_free. THFree and THRealloc use free/realloc instead. */
static void* THAllocInternal(ptrdiff_t size)
{
  enum { kAlignedThreshold = 5120, kAlignment = 64 };

#if (defined(__unix) || defined(__APPLE__)) && (!defined(DISABLE_POSIX_MEMALIGN))
  if (size > kAlignedThreshold) {
    void *ptr = NULL;
    /* posix_memalign does not guarantee to null ptr on failure; do it here. */
    if (posix_memalign(&ptr, kAlignment, (size_t)size) != 0) {
      ptr = NULL;
    }
    return ptr;
  }
#endif

  return malloc((size_t)size);
}
| |
| void* THAlloc(ptrdiff_t size) |
| { |
| void *ptr; |
| |
| if(size < 0) |
| THError("$ Torch: invalid memory size -- maybe an overflow?"); |
| |
| if(size == 0) |
| return NULL; |
| |
| ptr = THAllocInternal(size); |
| |
| if(!ptr && torchGCFunction) { |
| torchGCFunction(torchGCData); |
| ptr = THAllocInternal(size); |
| } |
| |
| if(!ptr) |
| THError("$ Torch: not enough memory: you tried to allocate %dGB. Buy new RAM!", size/1073741824); |
| |
| return ptr; |
| } |
| |
| void* THRealloc(void *ptr, ptrdiff_t size) |
| { |
| if(!ptr) |
| return(THAlloc(size)); |
| |
| if(size == 0) |
| { |
| THFree(ptr); |
| return NULL; |
| } |
| |
| if(size < 0) |
| THError("$ Torch: invalid memory size -- maybe an overflow?"); |
| |
| void *newptr = realloc(ptr, size); |
| |
| if(!newptr && torchGCFunction) { |
| torchGCFunction(torchGCData); |
| newptr = realloc(ptr, size); |
| } |
| |
| if(!newptr) |
| THError("$ Torch: not enough memory: you tried to reallocate %dGB. Buy new RAM!", size/1073741824); |
| |
| return newptr; |
| } |
| |
/* Releases a block obtained from THAlloc/THRealloc with free(); like free(),
 * accepts NULL as a no-op. */
void THFree(void *ptr)
{
  free(ptr);
}
| |
| double THLog1p(const double x) |
| { |
| #if (defined(_MSC_VER) || defined(__MINGW32__)) |
| volatile double y = 1 + x; |
| return log(y) - ((y-1)-x)/y ; /* cancels errors with IEEE arithmetic */ |
| #else |
| return log1p(x); |
| #endif |
| } |
| |
| double THExpm1(const double x) |
| { |
| return expm1(x); |
| } |
| |
/* Forwards the requested thread count to every threading backend compiled in
 * (OpenMP and/or MKL). With neither backend this is a no-op. */
void THSetNumThreads(int num_threads)
{
  (void)num_threads; /* unused when no backend is compiled in */
#ifdef _OPENMP
  omp_set_num_threads(num_threads);
#endif
#ifdef TH_BLAS_MKL
  MKL_Set_Num_Threads(num_threads);
#endif
}
| |
/* Returns OpenMP's maximum thread count, or 1 in builds without OpenMP. */
int THGetNumThreads(void)
{
  int count = 1;
#ifdef _OPENMP
  count = omp_get_max_threads();
#endif
  return count;
}
| |
/* Returns the processor count reported by OpenMP, or 1 in builds without
 * OpenMP. */
int THGetNumCores(void)
{
  int cores = 1;
#ifdef _OPENMP
  cores = omp_get_num_procs();
#endif
  return cores;
}
| |
| TH_API void THInferNumThreads(void) |
| { |
| #if defined(_OPENMP) && defined(TH_BLAS_MKL) |
| // If we are using MKL an OpenMP make sure the number of threads match. |
| // Otherwise, MKL and our OpenMP-enabled functions will keep changing the |
| // size of the OpenMP thread pool, resulting in worse performance (and memory |
| // leaks in GCC 5.4) |
| omp_set_num_threads(MKL_Get_Max_Threads()); |
| #endif |
| } |
| |
| TH_API THDescBuff _THSizeDesc(const int64_t *size, const int64_t ndim) { |
| const int L = TH_DESC_BUFF_LEN; |
| THDescBuff buf; |
| char *str = buf.str; |
| int i, n = 0; |
| n += snprintf(str, L-n, "["); |
| |
| for (i = 0; i < ndim; i++) { |
| if (n >= L) break; |
| n += snprintf(str+n, L-n, "%" PRId64, size[i]); |
| if (i < ndim-1) { |
| n += snprintf(str+n, L-n, " x "); |
| } |
| } |
| |
| if (n < L - 2) { |
| snprintf(str+n, L-n, "]"); |
| } else { |
| snprintf(str+L-5, 5, "...]"); |
| } |
| |
| return buf; |
| } |