Add flags to fix half comparison and test (#11395)

Summary:
I found there are some issues when using comparison operators for half types when certain THC headers are included. I was able to reproduce this and added a test. I also fixed the issue by adding the proper definitions to avoid this problem.

Reported in https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
Related: https://github.com/pytorch/tutorials/pull/292

soumith fmassa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11395

Differential Revision: D9725102

Pulled By: goldsborough

fbshipit-source-id: 630425829046bbebea3409bb792a9d62c91f41ad
diff --git a/.gitignore b/.gitignore
index da78355..110046e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,7 @@
 test/data/legacy_serialized.pt
 test/data/linear.pt
 test/htmlcov
+test/cpp_extensions/install/
 third_party/build/
 tools/shared/_utils_internal.py
 torch.egg-info/
diff --git a/test/cpp_extensions/half_support.cpp b/test/cpp_extensions/half_support.cpp
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/cpp_extensions/half_support.cpp
diff --git a/test/cpp_extensions/half_support.cu b/test/cpp_extensions/half_support.cu
new file mode 100644
index 0000000..a3621bf
--- /dev/null
+++ b/test/cpp_extensions/half_support.cu
@@ -0,0 +1,19 @@
+#include <torch/torch.h>
+
+#include <THC/THCNumerics.cuh>
+
+template <typename T, typename U>
+__global__ void half_test_kernel(const T* input, U* output) {
+  if (input[0] < input[1] || input[0] >= input[1]) {
+    output[0] = 123;
+  }
+}
+
+at::Tensor half_test(at::Tensor input) {
+  auto output = at::empty(1, input.options().dtype(at::kFloat));
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+    half_test_kernel<scalar_t>
+        <<<1, 1>>>(input.data<scalar_t>(), output.data<float>());
+  });
+  return output;
+}
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index a5312cd..f24571e 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -274,6 +274,47 @@
 
         torch.empty(2, 2, dtype=torch.complex64)
 
+    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+    def test_half_support(self):
+        '''
+        Checks for an issue with operator< ambiguity for half when certain
+        THC headers are included.
+
+        See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
+        for the corresponding issue.
+        '''
+        cuda_source = '''
+        #include <THC/THCNumerics.cuh>
+
+        template<typename T, typename U>
+        __global__ void half_test_kernel(const T* input, U* output) {
+            if (input[0] < input[1] || input[0] >= input[1]) {
+                output[0] = 123;
+            }
+        }
+
+        at::Tensor half_test(at::Tensor input) {
+            auto output = at::empty(1, input.options().dtype(at::kFloat));
+            AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+                half_test_kernel<scalar_t><<<1, 1>>>(
+                    input.data<scalar_t>(),
+                    output.data<float>());
+            });
+            return output;
+        }
+        '''
+
+        module = torch.utils.cpp_extension.load_inline(
+            name='half_test_extension',
+            cpp_sources='at::Tensor half_test(at::Tensor input);',
+            cuda_sources=cuda_source,
+            functions=['half_test'],
+            verbose=True)
+
+        x = torch.randn(3, device='cuda', dtype=torch.half)
+        result = module.half_test(x)
+        self.assertEqual(result[0], 123)
+
 
 if __name__ == '__main__':
     common.run_tests()
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 0af8026..43edbd2 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -69,6 +69,12 @@
 # it the below pattern.
 BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
 
+COMMON_NVCC_FLAGS = [
+    '-D__CUDA_NO_HALF_OPERATORS__',
+    '-D__CUDA_NO_HALF_CONVERSIONS__',
+    '-D__CUDA_NO_HALF2_OPERATORS__',
+]
+
 
 def is_binary_build():
     return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
@@ -165,7 +171,7 @@
                     self.compiler.set_executable('compiler_so', nvcc)
                     if isinstance(cflags, dict):
                         cflags = cflags['nvcc']
-                    cflags += ['--compiler-options', "'-fPIC'"]
+                    cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
                 elif isinstance(cflags, dict):
                     cflags = cflags['cxx']
                 # NVCC does not allow multiple -std to be passed, so we avoid
@@ -831,7 +837,7 @@
     flags = ['cflags = {}'.format(' '.join(cflags))]
 
     if with_cuda:
-        cuda_flags = common_cflags
+        cuda_flags = common_cflags + COMMON_NVCC_FLAGS
         if sys.platform == 'win32':
             cuda_flags = _nt_quote_args(cuda_flags)
         else: