Refactor THCNumerics and add common math functions for at::Half (#10301)

Summary:
**Summary**: This PR is a follow-up to mruberry's https://github.com/pytorch/pytorch/pull/9318/. It does the following:
- Specialize common std math functions for the `at::Half` type (a sketch of the pattern follows this list).
- Create `CUDANumerics.cuh` to contain necessary parts from `THCNumerics.cuh`.
- Update `THCNumerics.cuh` with new usage notes and comments that demonstrate best practice for developers, paving the way for its deprecation.
- Remove legacy/redundant code paths.
- Remove unused CUDA HALF macros (see separate PR https://github.com/pytorch/pytorch/pull/10147)
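
As a rough illustration of the `at::Half` specialization pattern (a simplified sketch, not the exact contents of `Half-inl.h`; it assumes `at::Half` converts implicitly to and from `float`):

```cpp
#include <cmath>
#include "ATen/ATen.h"

namespace std {

// Sketch: forward each math function to its float overload through the
// implicit at::Half <-> float conversions, then narrow back to at::Half.
inline at::Half sin(at::Half a) {
  return static_cast<at::Half>(std::sin(static_cast<float>(a)));
}

inline at::Half exp(at::Half a) {
  return static_cast<at::Half>(std::exp(static_cast<float>(a)));
}

inline at::Half pow(at::Half a, at::Half b) {
  return static_cast<at::Half>(
      std::pow(static_cast<float>(a), static_cast<float>(b)));
}

} // namespace std
```

With such specializations in place, generic code can call `std::sin(h)` on an `at::Half` the same way it would on `float` or `double`.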

**Comments**: `CUDANumerics.cuh` contains mathematical functions that are either not in the std namespace or are specialized for compilation with CUDA NVCC or CUDA NVRTC. This header is derived from the legacy `THCNumerics.cuh`. The following is some of the rationale for which functions were kept and which were removed:
- All arithmetic can now be done in ATen using binary CUDA kernels or CUDA tensor pointwise apply (see https://github.com/pytorch/pytorch/pull/8919 and `CUDAApplyUtils`). `at::Half` comparisons rely on implicit conversion to float.
- Functions that are C/C++ standard compliant have been specialized for user-defined types; for instance, the std namespace has been opened up for `at::Half`, providing math function definitions for `at::Half`. See `Half-inl.h`.
- Some standard-compliant functions are specialized here for performance reasons. For instance, `powi` is used for `pow` computation on integral types. Moreover, `abs`, `isinf`, and `isnan` are specialized to save one API call compared to going through std, although this is subject to change, depending on whether we really care about saving one API call.
- Numeric limits such as `max`/`min` are removed since they just call standard defines. Moreover, numeric limits for `at::Half` are present in `Half-inl.h`. I understand that HIP has some issues with `std::numeric_limits`; the related GitHub issue I found is https://github.com/ROCm-Developer-Tools/HIP/issues/374. AlexVlx mentions that the issue can be avoided by calling `std::numeric_limits` from `__device__` code. Since we launch lambdas in device contexts, I don't see why `std::numeric_limits` wouldn't compile on HIP when used within a kernel, unless I am unaware of the real reason why `max`/`min` was in THCNumerics in the first place. (I have never tried a HIP build; see the sketch after this list.)
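
For reference, here is a minimal sketch of how the new `at::numeric_limits` is meant to be consumed from device code (mirroring the `SoftMax.cu` change below; the kernel name and launch configuration are illustrative, not part of this PR):

```cpp
#include "ATen/cuda/NumericLimits.cuh"

// Illustrative kernel (launch with <<<1, 1>>>): seed a running maximum with
// at::numeric_limits<T>::lowest() rather than THCNumerics<T>::min().
template <typename accscalar_t>
__global__ void max_reduce(const accscalar_t* input, accscalar_t* out, int n) {
  accscalar_t max_val = at::numeric_limits<accscalar_t>::lowest();
  for (int i = 0; i < n; ++i) {
    max_val = input[i] > max_val ? input[i] : max_val;
  }
  *out = max_val;
}
```

Because `lowest()` and `max()` are marked `__host__ __device__`, this compiles for both CUDA and, in principle, ROCm, which is the gap `std::numeric_limits` currently leaves on HIP.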

Here are some reference PRs that were handy in refactoring TH into ATen:
- https://github.com/pytorch/pytorch/pull/6786
- https://github.com/pytorch/pytorch/pull/5475
- https://github.com/pytorch/pytorch/pull/9401
- https://github.com/pytorch/pytorch/pull/8689
- https://github.com/pytorch/pytorch/pull/8919
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10301

Differential Revision: D9204758

Pulled By: soumith

fbshipit-source-id: 09f489c1656458c02367b6cd31c3eeeca5acdc8a
diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh
new file mode 100644
index 0000000..9869649
--- /dev/null
+++ b/aten/src/ATen/cuda/NumericLimits.cuh
@@ -0,0 +1,75 @@
+#pragma once
+
+#include <cuda.h>
+#include <limits.h>
+
+// NumericLimits.cuh is a holder for numeric limits definitions of commonly used
+// types. This header is very specific to ROCm HIP and may be removed in the future. 
+// This header is derived from the legacy THCNumerics.cuh.
+
+namespace at{
+
+template <typename T>
+struct numeric_limits {
+};
+
+// WARNING: the following at::numeric_limits definitions are there only to support
+//          HIP compilation for the moment. Use std::numeric_limits if you are not
+//          compiling for ROCm.
+//          from @colesbury: "The functions on numeric_limits aren't marked with 
+//          __device__ which is why they don't work with ROCm. CUDA allows them 
+//          because they're constexpr."
+template <>
+struct numeric_limits<uint8_t> {
+  static inline __host__ __device__ uint8_t lowest() { return 0; }
+  static inline __host__ __device__ uint8_t max() { return UINT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int8_t> {
+  static inline __host__ __device__ int8_t lowest() { return INT8_MIN; }
+  static inline __host__ __device__ int8_t max() { return INT8_MAX; }
+};
+
+template <>
+struct numeric_limits<int16_t> {
+  static inline __host__ __device__ int16_t lowest() { return INT16_MIN; }
+  static inline __host__ __device__ int16_t max() { return INT16_MAX; }
+};
+
+template <>
+struct numeric_limits<int32_t> {
+  static inline __host__ __device__ int32_t lowest() { return INT32_MIN; }
+  static inline __host__ __device__ int32_t max() { return INT32_MAX; }
+};
+
+template <>
+struct numeric_limits<int64_t> {
+#ifdef _MSC_VER
+  static inline __host__ __device__ int64_t lowest() { return _I64_MIN; }
+  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
+#else
+  static inline __host__ __device__ int64_t lowest() { return INT64_MIN; }
+  static inline __host__ __device__ int64_t max() { return INT64_MAX; }
+#endif
+};
+
+template <>
+struct numeric_limits<at::Half> {
+  static inline __host__ __device__ at::Half lowest() { return at::Half(0xFBFF, at::Half::from_bits); }
+  static inline __host__ __device__ at::Half max() { return at::Half(0x7BFF, at::Half::from_bits); }
+};
+
+template <>
+struct numeric_limits<float> {
+  static inline __host__ __device__ float lowest() { return -FLT_MAX; }
+  static inline __host__ __device__ float max() { return FLT_MAX; }
+};
+
+template <>
+struct numeric_limits<double> {
+  static inline __host__ __device__ double lowest() { return -DBL_MAX; }
+  static inline __host__ __device__ double max() { return DBL_MAX; }
+};
+
+} // namespace at
\ No newline at end of file
diff --git a/aten/src/ATen/native/cuda/SoftMax.cu b/aten/src/ATen/native/cuda/SoftMax.cu
index 6351b8a..433e1be 100644
--- a/aten/src/ATen/native/cuda/SoftMax.cu
+++ b/aten/src/ATen/native/cuda/SoftMax.cu
@@ -6,9 +6,9 @@
 #include <THC/THCTensorMathReduce.cuh>
 #include <THC/THCTensorSort.cuh>
 #include <THC/THCThrustAllocator.cuh>
-#include <THC/THCNumerics.cuh>
 
 #include "ATen/AccumulateType.h"
+#include "ATen/cuda/NumericLimits.cuh"
 
 
 namespace at {
@@ -200,7 +200,7 @@
       ////////////////////////////////////////////////////////////
 
       if (blockDim.x > 1) {
-        accscalar_t max_input = THCNumerics<accscalar_t>::min();
+        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
           const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
           max_input = Max<accscalar_t>()(max_input, value);
@@ -217,7 +217,7 @@
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x)
           output[data_offset + d * dim_stride] = epilogue(input[data_offset + d * dim_stride]);
       } else {
-        accscalar_t max_input = THCNumerics<accscalar_t>::min();
+        accscalar_t max_input = at::numeric_limits<accscalar_t>::lowest();
         for (uint32_t d = threadIdx.x; d < dim_size; d += blockDim.x) {
           const accscalar_t value = static_cast<accscalar_t>(input[data_offset + d * dim_stride]);
           max_input = Max<accscalar_t>()(max_input, value);
@@ -403,9 +403,9 @@
 
   // find the max
   accscalar_t threadMax = ilpReduce<MaxFloat, ILP, scalar_t, accscalar_t>(
-      input, classes, MaxFloat<scalar_t, accscalar_t>(), -THCNumerics<accscalar_t>::max());
+      input, classes, MaxFloat<scalar_t, accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
   accscalar_t max_k = blockReduce<Max, accscalar_t>(
-      sdata, threadMax, Max<accscalar_t>(), -THCNumerics<accscalar_t>::max());
+      sdata, threadMax, Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());
 
   // reduce all values
   accscalar_t threadExp = ilpReduce<SumExpFloat, ILP, scalar_t, accscalar_t>(
diff --git a/aten/src/ATen/test/CMakeLists.txt b/aten/src/ATen/test/CMakeLists.txt
index 25d84a3..190f9de 100644
--- a/aten/src/ATen/test/CMakeLists.txt
+++ b/aten/src/ATen/test/CMakeLists.txt
@@ -25,7 +25,8 @@
   ${CMAKE_CURRENT_SOURCE_DIR}/integer_divider_test.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_rng_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/apply_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/stream_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_half_test.cu)
 if (CUDNN_FOUND)
   list(APPEND ATen_CUDA_TEST_SRCS
     ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_test.cpp)
diff --git a/aten/src/ATen/test/cuda_half_test.cu b/aten/src/ATen/test/cuda_half_test.cu
new file mode 100644
index 0000000..fa00e53
--- /dev/null
+++ b/aten/src/ATen/test/cuda_half_test.cu
@@ -0,0 +1,90 @@
+#define CATCH_CONFIG_MAIN
+#include "catch.hpp"
+
+#include "ATen/ATen.h"
+#include "ATen/cuda/NumericLimits.cuh"
+#include "cuda.h"
+#include "cuda_fp16.h"
+#include "cuda_runtime.h"
+
+#include <assert.h>
+
+using namespace at;
+
+__device__ void test(){
+  
+  // test half construction and implicit conversions in device
+  assert(Half(3) == Half(3.0f));
+  assert(static_cast<Half>(3.0f) == Half(3.0f));
+  // there is no float <=> __half implicit conversion
+  assert(static_cast<Half>(3.0f) == 3.0f);
+
+  __half a = __float2half(3.0f);
+  __half b = __float2half(2.0f);
+  __half c = a - Half(b);
+  assert(static_cast<Half>(c) == Half(1.0));
+
+  // Asserting that the functions used on
+  // half types give results that are almost
+  // equivalent to the same functions used on float.
+  // The purpose of these asserts is to test the device-side
+  // half API for the common mathematical functions.
+  // Note: when calling std math functions from device code, don't
+  // use the std namespace, but just "::" so that the function
+  // gets resolved from nvcc's math_functions.hpp
+
+  float threshold = 0.00001;
+  assert(::abs(::lgamma(Half(10.0)) - ::lgamma(10.0f)) <= threshold);
+  assert(::abs(::exp(Half(1.0)) - ::exp(1.0f)) <= threshold);
+  assert(::abs(::log(Half(1.0)) - ::log(1.0f)) <= threshold);
+  assert(::abs(::log10(Half(1000.0)) - ::log10(1000.0f)) <= threshold);
+  assert(::abs(::log1p(Half(0.0)) - ::log1p(0.0f)) <= threshold);
+  assert(::abs(::log2(Half(1000.0)) - ::log2(1000.0f)) <= threshold);
+  assert(::abs(::expm1(Half(1.0)) - ::expm1(1.0f)) <= threshold);
+  assert(::abs(::cos(Half(0.0)) - ::cos(0.0f)) <= threshold);
+  assert(::abs(::sin(Half(0.0)) - ::sin(0.0f)) <= threshold);
+  assert(::abs(::sqrt(Half(100.0)) - ::sqrt(100.0f)) <= threshold);
+  assert(::abs(::ceil(Half(2.4)) - ::ceil(2.4f)) <= threshold);
+  assert(::abs(::floor(Half(2.7)) - ::floor(2.7f)) <= threshold);
+  assert(::abs(::trunc(Half(2.7)) - ::trunc(2.7f)) <= threshold);
+  assert(::abs(::acos(Half(-1.0)) - ::acos(-1.0f)) <= threshold);
+  assert(::abs(::cosh(Half(1.0)) - ::cosh(1.0f)) <= threshold);
+  assert(::abs(::acosh(Half(1.0)) - ::acosh(1.0f)) <= threshold);
+  assert(::abs(::asin(Half(1.0)) - ::asin(1.0f)) <= threshold);
+  assert(::abs(::sinh(Half(1.0)) - ::sinh(1.0f)) <= threshold);
+  assert(::abs(::asinh(Half(1.0)) - ::asinh(1.0f)) <= threshold);
+  assert(::abs(::tan(Half(0.0)) - ::tan(0.0f)) <= threshold);
+  assert(::abs(::atan(Half(1.0)) - ::atan(1.0f)) <= threshold);
+  assert(::abs(::tanh(Half(1.0)) - ::tanh(1.0f)) <= threshold);
+  assert(::abs(::erf(Half(10.0)) - ::erf(10.0f)) <= threshold);
+  assert(::abs(::erfc(Half(10.0)) - ::erfc(10.0f)) <= threshold);
+  assert(::abs(::abs(Half(-3.0)) - ::abs(-3.0f)) <= threshold);
+  assert(::abs(::round(Half(2.3)) - ::round(2.3f)) <= threshold);
+  assert(::abs(::pow(Half(2.0), Half(10.0)) - ::pow(2.0f, 10.0f)) <= threshold);
+  assert(::abs(::atan2(Half(7.0), Half(0.0)) - ::atan2(7.0f, 0.0f)) <= threshold);
+  // note: can't use the std namespace on isnan and isinf in device code
+  #ifdef _MSC_VER
+    // Windows requires this explicit conversion. The reason is unclear;
+    // related issue with clang: https://reviews.llvm.org/D37906
+    assert(::abs(::isnan((float)Half(0.0)) - ::isnan(0.0f)) <= threshold);
+    assert(::abs(::isinf((float)Half(0.0)) - ::isinf(0.0f)) <= threshold);
+  #else
+    assert(::abs(::isnan(Half(0.0)) - ::isnan(0.0f)) <= threshold);
+    assert(::abs(::isinf(Half(0.0)) - ::isinf(0.0f)) <= threshold);
+  #endif
+}
+
+__global__ void kernel(){
+  test();
+}
+
+void launch_function(){
+  kernel<<<1,1>>>();
+}
+
+TEST_CASE( "half common math functions tests in device", "[cuda]" ) {
+  launch_function();
+  cudaError_t err = cudaDeviceSynchronize();
+  REQUIRE(err == cudaSuccess);
+}
+
diff --git a/aten/src/ATen/test/half_test.cpp b/aten/src/ATen/test/half_test.cpp
index fc70522..3b29448 100644
--- a/aten/src/ATen/test/half_test.cpp
+++ b/aten/src/ATen/test/half_test.cpp
@@ -5,7 +5,10 @@
 #include <iostream>
 #include <limits>
 #include <sstream>
+#include <cmath>
 #include <type_traits>
+#include "test_seed.h"
+#include "test_assert.h"
 
 using namespace at;
 
@@ -115,3 +118,43 @@
 ASSERT_SAME_TYPE(max_exponent10);
 ASSERT_SAME_TYPE(traps);
 ASSERT_SAME_TYPE(tinyness_before);
+
+TEST_CASE( "half common math functions test", "[]" ) {
+  float threshold = 0.00001;
+  assert(std::abs(std::lgamma(Half(10.0)) - std::lgamma(10.0f)) <= threshold);
+  assert(std::abs(std::exp(Half(1.0)) - std::exp(1.0f)) <= threshold);
+  assert(std::abs(std::log(Half(1.0)) - std::log(1.0f)) <= threshold);
+  assert(std::abs(std::log10(Half(1000.0)) - std::log10(1000.0f)) <= threshold);
+  assert(std::abs(std::log1p(Half(0.0)) - std::log1p(0.0f)) <= threshold);
+  assert(std::abs(std::log2(Half(1000.0)) - std::log2(1000.0f)) <= threshold);
+  assert(std::abs(std::expm1(Half(1.0)) - std::expm1(1.0f)) <= threshold);
+  assert(std::abs(std::cos(Half(0.0)) - std::cos(0.0f)) <= threshold);
+  assert(std::abs(std::sin(Half(0.0)) - std::sin(0.0f)) <= threshold);
+  assert(std::abs(std::sqrt(Half(100.0)) - std::sqrt(100.0f)) <= threshold);
+  assert(std::abs(std::ceil(Half(2.4)) - std::ceil(2.4f)) <= threshold);
+  assert(std::abs(std::floor(Half(2.7)) - std::floor(2.7f)) <= threshold);
+  assert(std::abs(std::trunc(Half(2.7)) - std::trunc(2.7f)) <= threshold);
+  assert(std::abs(std::acos(Half(-1.0)) - std::acos(-1.0f)) <= threshold);
+  assert(std::abs(std::cosh(Half(1.0)) - std::cosh(1.0f)) <= threshold);
+  assert(std::abs(std::acosh(Half(1.0)) - std::acosh(1.0f)) <= threshold);
+  assert(std::abs(std::asin(Half(1.0)) - std::asin(1.0f)) <= threshold);
+  assert(std::abs(std::sinh(Half(1.0)) - std::sinh(1.0f)) <= threshold);
+  assert(std::abs(std::asinh(Half(1.0)) - std::asinh(1.0f)) <= threshold);
+  assert(std::abs(std::tan(Half(0.0)) - std::tan(0.0f)) <= threshold);
+  assert(std::abs(std::atan(Half(1.0)) - std::atan(1.0f)) <= threshold);
+  assert(std::abs(std::tanh(Half(1.0)) - std::tanh(1.0f)) <= threshold);
+  assert(std::abs(std::erf(Half(10.0)) - std::erf(10.0f)) <= threshold);
+  assert(std::abs(std::erfc(Half(10.0)) - std::erfc(10.0f)) <= threshold);
+  assert(std::abs(std::abs(Half(-3.0)) - std::abs(-3.0f)) <= threshold);
+  assert(std::abs(std::round(Half(2.3)) - std::round(2.3f)) <= threshold);
+  assert(std::abs(std::pow(Half(2.0), Half(10.0)) - std::pow(2.0f, 10.0f)) <= threshold);
+  assert(std::abs(std::atan2(Half(7.0), Half(0.0)) - std::atan2(7.0f, 0.0f)) <= threshold);
+  #ifdef __APPLE__
+    // @TODO: can macos do implicit conversion of Half?
+    assert(std::abs(std::isnan(static_cast<float>(Half(0.0))) - std::isnan(0.0f)) <= threshold);
+    assert(std::abs(std::isinf(static_cast<float>(Half(0.0))) - std::isinf(0.0f)) <= threshold);
+  #else
+    assert(std::abs(std::isnan(Half(0.0)) - std::isnan(0.0f)) <= threshold);
+    assert(std::abs(std::isinf(Half(0.0)) - std::isinf(0.0f)) <= threshold);
+  #endif
+}
\ No newline at end of file
diff --git a/aten/src/THC/CMakeLists.txt b/aten/src/THC/CMakeLists.txt
index 5f92ea3..44f5d18 100644
--- a/aten/src/THC/CMakeLists.txt
+++ b/aten/src/THC/CMakeLists.txt
@@ -18,10 +18,6 @@
    endforeach()
 endforeach()
 
-IF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
-  LIST(APPEND extra_src ${CMAKE_CURRENT_SOURCE_DIR}/THCHalf.cu)
-ENDIF()
-
 set(ATen_CUDA_SRCS ${ATen_CUDA_SRCS}
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingAllocator.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/THCCachingHostAllocator.cpp
diff --git a/aten/src/THC/THCAtomics.cuh b/aten/src/THC/THCAtomics.cuh
index f040699..485b574 100644
--- a/aten/src/THC/THCAtomics.cuh
+++ b/aten/src/THC/THCAtomics.cuh
@@ -4,8 +4,7 @@
 #include "THC.h"
 #include "THCHalf.h"
 #include "THCNumerics.cuh"
-
-namespace at { struct Half; }
+#include "ATen/ATen.h"
 
 template <typename T, size_t n>
 struct AtomicAddIntegerImpl;
@@ -118,8 +117,8 @@
     old = atomicCAS(address_as_ui, assumed, old);
   } while (assumed != old);
 }
-static inline __device__ void atomicAdd(at::Half *address, half val) {
-  return atomicAdd(reinterpret_cast<half*>(address), val);
+static inline __device__ void atomicAdd(at::Half *address, at::Half val) {
+  atomicAdd(reinterpret_cast<half*>(address), val);
 }
 
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 600 || CUDA_VERSION < 8000)
diff --git a/aten/src/THC/THCHalf.cu b/aten/src/THC/THCHalf.cu
deleted file mode 100644
index 7863260..0000000
--- a/aten/src/THC/THCHalf.cu
+++ /dev/null
@@ -1,51 +0,0 @@
-#include "THCHalf.h"
-#include "THCThrustAllocator.cuh"
-#include <thrust/transform.h>
-#include <thrust/execution_policy.h>
-
-struct __half2floatOp {
-  __device__ float operator()(half v) { return __half2float(v); }
-};
-
-struct __float2halfOp {
-  __device__ half operator()(float v) { return __float2half(v); }
-};
-
-void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len) {
-  THCThrustAllocator thrustAlloc(state);
-  thrust::transform(
-#if CUDA_VERSION >= 7000
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
-#else
-    thrust::device,
-#endif
-    in, in + len, out, __float2halfOp());
-}
-
-void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len) {
-  THCThrustAllocator thrustAlloc(state);
-  thrust::transform(
-#if CUDA_VERSION >= 7000
-    thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
-#else
-    thrust::device,
-#endif
-    in, in + len, out, __half2floatOp());
-}
-
-THC_EXTERNC int THC_nativeHalfInstructions(THCState *state) {
-  cudaDeviceProp* prop =
-    THCState_getCurrentDeviceProperties(state);
-
-  // CC 5.3+
-  return (prop->major > 5 ||
-          (prop->major == 5 && prop->minor == 3));
-}
-
-THC_EXTERNC int THC_fastHalfInstructions(THCState *state) {
-  cudaDeviceProp* prop =
-    THCState_getCurrentDeviceProperties(state);
-
-  // Check for CC 6.0 only (corresponds to P100)
-  return (prop->major == 6 && prop->minor == 0);
-}
diff --git a/aten/src/THC/THCHalf.h b/aten/src/THC/THCHalf.h
index 6b9a4f7..aeae06f 100644
--- a/aten/src/THC/THCHalf.h
+++ b/aten/src/THC/THCHalf.h
@@ -12,15 +12,7 @@
 #endif
 #endif
 
-THC_EXTERNC void THCFloat2Half(THCState *state, half *out, float *in, ptrdiff_t len);
-THC_EXTERNC void THCHalf2Float(THCState *state, float *out, half *in, ptrdiff_t len);
 THC_API half THC_float2half(float a);
 THC_API float THC_half2float(half a);
 
-/* Check for native fp16 support on the current device (CC 5.3+) */
-THC_API int THC_nativeHalfInstructions(THCState *state);
-
-/* Check for performant native fp16 support on the current device */
-THC_API int THC_fastHalfInstructions(THCState *state);
-
 #endif
diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh
index 9e653b7..9af18f6 100644
--- a/aten/src/THC/THCNumerics.cuh
+++ b/aten/src/THC/THCNumerics.cuh
@@ -5,11 +5,18 @@
 #include <limits.h>
 #include <assert.h>
 #include "THCHalf.h"
+#include "ATen/ATen.h"
+#include "ATen/cuda/NumericLimits.cuh"
 
-/// Class for numeric limits of the particular data type, which
-/// includes support for `half`.
-/// Unfortunately since `half` does not have a constructor, these have
-/// to be expressed as functions (either that or non-const statics).
+// WARNING: THCNumerics is being deprecated. Please follow the comments
+// in this file to learn about new usages.
+// Comments on usage:
+//      - lt,le,gt,ge,eq,neg,add,mul,sub,div and other binary ops can
+//        be implemented using CUDA_apply_utils or binary cuda kernel
+//      - Check NumericLimits.cuh for specialized math functions.
+//      - Note how __half and at::Half can be cast; for instance:
+//        static_cast<at::Half>(std::sin(static_cast<at::Half>(a)));
+
 template <typename T>
 struct THCNumerics {
 };
@@ -28,10 +35,12 @@
   return result;
 }
 
+// DEPRECATED: For integral types, use math functions from std and NumericLimits.cuh. 
+//             Use binary_kernel or CUDA_apply_utils for arithmetic
 template <>
 struct THCNumerics<uint8_t> {
-  static inline __host__ __device__ uint8_t min() { return 0; }
-  static inline __host__ __device__ uint8_t max() { return UCHAR_MAX; }
+  static inline __host__ __device__ uint8_t min() { return at::numeric_limits<uint8_t>::lowest(); }
+  static inline __host__ __device__ uint8_t max() { return at::numeric_limits<uint8_t>::max(); }
 
   static inline __host__ __device__ bool lt(uint8_t a, uint8_t b) { return a < b; }
   static inline __host__ __device__ bool le(uint8_t a, uint8_t b) { return a <= b; }
@@ -53,8 +62,8 @@
 
 template <>
 struct THCNumerics<int8_t> {
-  static inline __host__ __device__ int8_t min() { return SCHAR_MIN; }
-  static inline __host__ __device__ int8_t max() { return SCHAR_MAX; }
+  static inline __host__ __device__ int8_t min() { return at::numeric_limits<int8_t>::lowest(); }
+  static inline __host__ __device__ int8_t max() { return at::numeric_limits<int8_t>::max(); }
 
   static inline __host__ __device__ bool lt(int8_t a, int8_t b) { return a < b; }
   static inline __host__ __device__ bool le(int8_t a, int8_t b) { return a <= b; }
@@ -76,8 +85,8 @@
 
 template <>
 struct THCNumerics<int16_t> {
-  static inline __host__ __device__ int16_t min() { return SHRT_MIN; }
-  static inline __host__ __device__ int16_t max() { return SHRT_MAX; }
+  static inline __host__ __device__ int16_t min() { return at::numeric_limits<int16_t>::lowest(); }
+  static inline __host__ __device__ int16_t max() { return at::numeric_limits<int16_t>::max(); }
 
   static inline __host__ __device__ bool lt(int16_t a, int16_t b) { return a < b; }
   static inline __host__ __device__ bool le(int16_t a, int16_t b) { return a <= b; }
@@ -99,8 +108,8 @@
 
 template <>
 struct THCNumerics<int32_t> {
-  static inline __host__ __device__ int32_t min() { return INT_MIN; }
-  static inline __host__ __device__ int32_t max() { return INT_MAX; }
+  static inline __host__ __device__ int32_t min() { return at::numeric_limits<int32_t>::lowest(); }
+  static inline __host__ __device__ int32_t max() { return at::numeric_limits<int32_t>::max(); }
 
   static inline __host__ __device__ bool lt(int32_t a, int32_t b) { return a < b; }
   static inline __host__ __device__ bool le(int32_t a, int32_t b) { return a <= b; }
@@ -122,13 +131,8 @@
 
 template <>
 struct THCNumerics<int64_t> {
-#ifdef _MSC_VER
-  static inline __host__ __device__ int64_t min() { return _I64_MIN; }
-  static inline __host__ __device__ int64_t max() { return _I64_MAX; }
-#else
-  static inline __host__ __device__ int64_t min() { return LONG_MIN; }
-  static inline __host__ __device__ int64_t max() { return LONG_MAX; }
-#endif
+  static inline __host__ __device__ int64_t min() { return at::numeric_limits<int64_t>::lowest(); }
+  static inline __host__ __device__ int64_t max() { return at::numeric_limits<int64_t>::max(); }
 
   static inline __host__ __device__ bool lt(int64_t a, int64_t b) { return a < b; }
   static inline __host__ __device__ bool le(int64_t a, int64_t b) { return a <= b; }
@@ -149,430 +153,222 @@
   static inline __host__ __device__  bool isinf(int64_t a) { return false; }
 };
 
+// DEPRECATED: use math functions from std and NumericLimits.cuh
 template <>
 struct THCNumerics<half> {
-#if CUDA_VERSION < 9000
-  static inline __host__ __device__ half min() { half h; h.x = 0xfbff; return h; }
-  static inline __host__ __device__ half max() { half h; h.x = 0x7bff; return h; }
-#else
-  static inline __host__ __device__ half min() { __half_raw h; h.x = 0xfbff; return h; }
-  static inline __host__ __device__ half max() { __half_raw h; h.x = 0x7bff; return h; }
-#endif
+  static inline __host__ __device__ half min() { return at::numeric_limits<at::Half>::lowest(); }
+  static inline __host__ __device__ half max() { return at::numeric_limits<at::Half>::max(); }
 
   static inline __host__ __device__ bool lt(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa < fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) < THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) < static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool le(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa <= fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) <= THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) <= static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool gt(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa > fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) > THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) > static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool ge(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa >= fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) >= THC_half2float(b);
-#endif
+    return static_cast<at::Half>(a) >= static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ bool eq(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa == fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) == THC_half2float(b);
-#endif
+    // has to be explicitly cast to float for now, otherwise we get the error: more than one operator "==" matches these operands
+    // Note: find the overloading for == and != (probably THCTensorTypeUtils.cuh) and resolve
+    return static_cast<float>(static_cast<at::Half>(a)) == static_cast<float>(static_cast<at::Half>(b));
   }
 
   static inline __host__ __device__ bool ne(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return fa != fb;
-#else // __CUDA_ARCH__
-    return THC_half2float(a) != THC_half2float(b);
-#endif
+    // has to be explicitly cast to float for now, otherwise we get the error: more than one operator "==" matches these operands
+    // Note: find the overloading for == and != (probably THCTensorTypeUtils.cuh) and resolve
+    return static_cast<float>(static_cast<at::Half>(a)) != static_cast<float>(static_cast<at::Half>(b));
   }
 
   static inline __host__ __device__ half exp(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(expf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(expf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(std::exp(static_cast<at::Half>(a)));
   }
-
+  
+  // note that exp10 is not in the std namespace. 
   static inline __host__ __device__ half exp10(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(exp10f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(exp10f(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::exp10(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half log(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(logf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(logf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::log(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half log10(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log10f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log10f(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::log10(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half log1p(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log1pf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log1pf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::log1p(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half log2(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(log2f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(log2f(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::log2(static_cast<at::Half>(a)));
   }
 
-static inline __host__ __device__ half lgamma(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(lgammaf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(lgammaf(THC_half2float(a)));
-#endif
+  static inline __host__ __device__ half lgamma(half a) {
+    return static_cast<at::Half>(::lgamma(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half expm1(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(expm1f(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(expm1f(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::expm1(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half cos(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(cosf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(cosf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::cos(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half sin(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sinf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sinf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::sin(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half sqrt(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sqrtf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sqrtf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::sqrt(static_cast<at::Half>(a)));
   }
 
+  // note that rsqrt is not in the std namespace. 
   static inline __host__ __device__ half rsqrt(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(rsqrtf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(rsqrtf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::rsqrt(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half ceil(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(ceilf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(ceilf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::ceil(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half floor(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(floorf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(floorf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::floor(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half trunc(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(truncf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(truncf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::trunc(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half neg(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(-fa);
-#else // __CUDA_ARCH__
-    return THC_float2half(-(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(-(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half acos(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(acosf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(acosf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::acos(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half cosh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(coshf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(coshf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::cosh(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half asin(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(asinf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(asinf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::asin(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half sinh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(sinhf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(sinhf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::sinh(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half tan(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(tanf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(tanf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::tan(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half atan(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(atanf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(atanf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::atan(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half tanh(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(tanhf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(tanhf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::tanh(static_cast<at::Half>(a)));
   }
 
 
    static inline __host__ __device__ half erf(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erff(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erff(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::erf(static_cast<at::Half>(a)));
   }
 
 
    static inline __host__ __device__ half erfc(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erfcf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erfcf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::erfc(static_cast<at::Half>(a)));
   }
 
-
+  // note that erfinv is not in the std namespace.
   static inline __host__ __device__ half erfinv(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(erfinvf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(erfinvf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::erfinv(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half abs(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(fabs(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(fabs(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::abs(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half round(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(roundf(fa));
-#else // __CUDA_ARCH__
-    return THC_float2half(roundf(THC_half2float(a)));
-#endif
+    return static_cast<at::Half>(::round(static_cast<at::Half>(a)));
   }
 
   static inline __host__ __device__ half frac(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(fa - truncf(fa));
-#else // __CUDA_ARCH__
-    float fa = THC_half2float(a);
-    return THC_float2half(fa - floorf(fa));
-#endif
+    #ifdef __CUDA_ARCH__
+        return static_cast<at::Half>(a) - static_cast<at::Half>(::trunc(static_cast<at::Half>(a)));
+    #else // __CUDA_ARCH__
+        return static_cast<at::Half>(a) - static_cast<at::Half>(::floor(static_cast<at::Half>(a)));
+    #endif
   }
 
   static inline __host__ __device__ half cinv(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return __float2half(1.0f / fa);
-#else // __CUDA_ARCH__
-    return THC_float2half(1.0f / THC_half2float(a));
-#endif
+    return static_cast<at::Half>(1.0f / static_cast<at::Half>(a));
   }
 
   static inline __host__ __device__ half add(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa + fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) + THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) + static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half div(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa / fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) / THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) / static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half mul(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa * fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) * THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) * static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half sub(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half( fa - fb );
-#else // __CUDA_ARCH__
-    return THC_float2half(THC_half2float(a) - THC_half2float(b));
-#endif
+    return static_cast<at::Half>(a) - static_cast<at::Half>(b);
   }
 
   static inline __host__ __device__ half pow(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half(powf(fa, fb));
-#else // __CUDA_ARCH__
-    return THC_float2half(powf(THC_half2float(a), THC_half2float(b)));
-#endif
+    return static_cast<at::Half>(::pow(static_cast<at::Half>(a), static_cast<at::Half>(b)));
   }
 
   static inline __host__ __device__ half atan2(half a, half b) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    float fb = __half2float(b);
-    return __float2half(atan2f(fa, fb));
-#else // __CUDA_ARCH__
-    return THC_float2half(atan2f(THC_half2float(a), THC_half2float(b)));
-#endif
+    return static_cast<at::Half>(::atan2(static_cast<at::Half>(a), static_cast<at::Half>(b)));
   }
 
   static inline __host__ __device__ bool isnan(half a) {
-    // implemented using that a!=a if and only if a is nan
-    return ne(a, a);
+    #ifdef _MSC_VER
+      // Windows requires this explicit conversion. The reason is unclear;
+      // related issue with clang: https://reviews.llvm.org/D37906
+      return ::isnan((float)static_cast<at::Half>(a));
+    #else
+      return ::isnan(static_cast<at::Half>(a));
+    #endif
   }
 
   static inline __host__ __device__ bool isinf(half a) {
-#ifdef __CUDA_ARCH__
-    float fa = __half2float(a);
-    return ::isinf(fa);
-#else // __CUDA_ARCH__
-    return ::isinf(THC_half2float(a));
-#endif
+    #ifdef _MSC_VER
+      // Windows requires this explicit conversion. The reason is unclear;
+      // related issue with clang: https://reviews.llvm.org/D37906
+      return ::isinf((float)static_cast<at::Half>(a));
+    #else
+      return ::isinf(static_cast<at::Half>(a));
+    #endif
   }
 
 };
 
+// DEPRECATED: use math functions from std and cuda math API (if needed)
+//             note that the functions exp10,rsqrt,erfinv,frac and cinv
+//             are not in the std namespace
 template <>
 struct THCNumerics<float> {
-  static inline __host__ __device__ float min() { return -FLT_MAX; }
-  static inline __host__ __device__ float max() { return FLT_MAX; }
+  static inline __host__ __device__ float min() { return at::numeric_limits<float>::lowest(); }
+  static inline __host__ __device__ float max() { return at::numeric_limits<float>::max(); }
 
   static inline __host__ __device__ bool lt(float a, float b) { return a < b; }
   static inline __host__ __device__ bool le(float a, float b) { return a <= b; }
@@ -623,10 +419,13 @@
   static inline __host__ __device__  bool isinf(float a) { return ::isinf(a); }
 };
 
+// DEPRECATED: use math functions from std and cuda math API (if needed)
+//             note that the functions exp10,rsqrt,erfinv,frac and cinv
+//             are not in the std namespace
 template <>
 struct THCNumerics<double> {
-  static inline __host__ __device__ double min() { return -DBL_MAX; }
-  static inline __host__ __device__ double max() { return DBL_MAX; }
+  static inline __host__ __device__ double min() { return at::numeric_limits<double>::lowest(); }
+  static inline __host__ __device__ double max() { return at::numeric_limits<double>::max(); }
 
   static inline __host__ __device__ bool lt(double a, double b) { return a < b; }
   static inline __host__ __device__ bool le(double a, double b) { return a <= b; }
@@ -677,10 +476,15 @@
   static inline __host__ __device__  bool isinf(double a) { return ::isinf(a); }
 };
 
-/// `half` has some type conversion issues associated with it, since it
-/// is a struct without a constructor/implicit conversion constructor.
-/// We use this to convert scalar values to the given type that the
-/// tensor expects.
+// WARNING: The following note is deprecated
+///       `half` has some type conversion issues associated with it, since it
+///        is a struct without a constructor/implicit conversion constructor.
+///        We use this to convert scalar values to the given type that the
+///        tensor expects.
+///
+/// at::Half has implicit conversions for float and __half types. Moreover
+/// it has constructors for __half and float types.
+
 template <typename In, typename Out>
 struct ScalarConvert {
   static __host__ __device__ Out to(const In v) { return (Out) v; }
@@ -715,6 +519,7 @@
   }
 };
 
+// DEPRECATED: use static_cast in kernels instead of scalar_cast
 template <typename T, typename U>
 __host__ __device__ T scalar_cast(U u) {
   return ScalarConvert<U, T>::to(u);
diff --git a/aten/src/THCUNN/THCHalfAutoNumerics.cuh b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
index fff37f8..5f8fda8 100644
--- a/aten/src/THCUNN/THCHalfAutoNumerics.cuh
+++ b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
@@ -4,6 +4,9 @@
 #include "THCHalf.h"
 #include "THCNumerics.cuh"
 
+// WARNING: THCNumerics is being deprecated. Read the comments and function usage 
+//          in THCNumerics to learn about the deprecation
+//      
 // Half numerics functions defined as free functions, so cunn code can be
 //written generically, i.e. without excessive calling of THCNumerics<half> functions.
 
diff --git a/aten/tools/run_tests.sh b/aten/tools/run_tests.sh
index c341b88..d2669029 100755
--- a/aten/tools/run_tests.sh
+++ b/aten/tools/run_tests.sh
@@ -28,6 +28,9 @@
 if [[ -x ./stream_test ]]; then
   ./stream_test
 fi
+if [[ -x ./cuda_half_test ]]; then
+  ./cuda_half_test
+fi
 if [ "$VALGRIND" == "ON" ]
 then
   valgrind --suppressions="$VALGRIND_SUP" --error-exitcode=1 ./basic "[cpu]"
diff --git a/tools/build_pytorch_libs.sh b/tools/build_pytorch_libs.sh
index 268ec12..994a96a 100755
--- a/tools/build_pytorch_libs.sh
+++ b/tools/build_pytorch_libs.sh
@@ -264,6 +264,7 @@
       -DCAFFE2_STATIC_LINK_CUDA=$CAFFE2_STATIC_LINK_CUDA \
       -DUSE_ROCM=$USE_ROCM \
       -DUSE_NNPACK=$USE_NNPACK \
+      -DCUDA_DEVICE_DEBUG=$CUDA_DEVICE_DEBUG \
       -DCUDNN_INCLUDE_DIR=$CUDNN_INCLUDE_DIR \
       -DCUDNN_LIB_DIR=$CUDNN_LIB_DIR \
       -DCUDNN_LIBRARY=$CUDNN_LIBRARY \