Add flags to fix half comparison and test (#11395)
Summary:
There are some issues when using comparison operators for half types when certain THC headers are included. I was able to reproduce the problem and added a test. I also fixed the issue by adding the proper preprocessor definitions to the default nvcc flags.
Reported in https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
Related: https://github.com/pytorch/tutorials/pull/292
cc: soumith, fmassa
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11395
Differential Revision: D9725102
Pulled By: goldsborough
fbshipit-source-id: 630425829046bbebea3409bb792a9d62c91f41ad
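For context, the new defines work because cuda_fp16.h only emits its
built-in __half operators when they are absent. A simplified sketch of the
guard, paraphrased from the CUDA headers (the exact code varies by CUDA
version):

    // Paraphrased from cuda_fp16.h, not part of this PR: the built-in
    // __half comparison operators are compiled only when the macro is
    // undefined, so passing -D__CUDA_NO_HALF_OPERATORS__ suppresses them
    // and leaves THC's overloads as the only candidates.
    #if !defined(__CUDA_NO_HALF_OPERATORS__)
    __device__ __forceinline__ bool operator<(const __half& lh, const __half& rh) {
      return __hlt(lh, rh);  // device half less-than intrinsic, sm_53+
    }
    #endif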
diff --git a/.gitignore b/.gitignore
index da78355..110046e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,6 +35,7 @@
test/data/legacy_serialized.pt
test/data/linear.pt
test/htmlcov
+test/cpp_extensions/install/
third_party/build/
tools/shared/_utils_internal.py
torch.egg-info/
diff --git a/test/cpp_extensions/half_support.cpp b/test/cpp_extensions/half_support.cpp
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/test/cpp_extensions/half_support.cpp
diff --git a/test/cpp_extensions/half_support.cu b/test/cpp_extensions/half_support.cu
new file mode 100644
index 0000000..a3621bf
--- /dev/null
+++ b/test/cpp_extensions/half_support.cu
@@ -0,0 +1,19 @@
+#include <torch/torch.h>
+
+#include <THC/THCNumerics.cuh>
+
+template <typename T, typename U>
+__global__ void half_test_kernel(const T* input, U* output) {
+ if (input[0] < input[1] || input[0] >= input[1]) {
+ output[0] = 123;
+ }
+}
+
+at::Tensor half_test(at::Tensor input) {
+ auto output = at::empty(1, input.options().dtype(at::kFloat));
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+ half_test_kernel<scalar_t>
+ <<<1, 1>>>(input.data<scalar_t>(), output.data<float>());
+ });
+ return output;
+}
diff --git a/test/test_cpp_extensions.py b/test/test_cpp_extensions.py
index a5312cd..f24571e 100755
--- a/test/test_cpp_extensions.py
+++ b/test/test_cpp_extensions.py
@@ -274,6 +274,47 @@
torch.empty(2, 2, dtype=torch.complex64)
+ @unittest.skipIf(not TEST_CUDA, "CUDA not found")
+ def test_half_support(self):
+ '''
+ Checks for an issue with operator< ambiguity for half when certain
+ THC headers are included.
+
+ See https://github.com/pytorch/pytorch/pull/10301#issuecomment-416773333
+ for the corresponding issue.
+ '''
+ cuda_source = '''
+ #include <THC/THCNumerics.cuh>
+
+ template<typename T, typename U>
+ __global__ void half_test_kernel(const T* input, U* output) {
+ if (input[0] < input[1] || input[0] >= input[1]) {
+ output[0] = 123;
+ }
+ }
+
+ at::Tensor half_test(at::Tensor input) {
+ auto output = at::empty(1, input.options().dtype(at::kFloat));
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.type(), "half_test", [&] {
+ half_test_kernel<scalar_t><<<1, 1>>>(
+ input.data<scalar_t>(),
+ output.data<float>());
+ });
+ return output;
+ }
+ '''
+
+ module = torch.utils.cpp_extension.load_inline(
+ name='half_test_extension',
+ cpp_sources='at::Tensor half_test(at::Tensor input);',
+ cuda_sources=cuda_source,
+ functions=['half_test'],
+ verbose=True)
+
+ x = torch.randn(3, device='cuda', dtype=torch.half)
+ result = module.half_test(x)
+ self.assertEqual(result[0], 123)
+
if __name__ == '__main__':
common.run_tests()
diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py
index 0af8026..43edbd2 100644
--- a/torch/utils/cpp_extension.py
+++ b/torch/utils/cpp_extension.py
@@ -69,6 +69,12 @@
# it the below pattern.
BUILT_FROM_SOURCE_VERSION_PATTERN = re.compile(r'\d+\.\d+\.\d+\w+\+\w+')
+COMMON_NVCC_FLAGS = [
+ '-D__CUDA_NO_HALF_OPERATORS__',
+ '-D__CUDA_NO_HALF_CONVERSIONS__',
+ '-D__CUDA_NO_HALF2_OPERATORS__',
+]
+
def is_binary_build():
return not BUILT_FROM_SOURCE_VERSION_PATTERN.match(torch.version.__version__)
@@ -165,7 +171,7 @@
self.compiler.set_executable('compiler_so', nvcc)
if isinstance(cflags, dict):
cflags = cflags['nvcc']
- cflags += ['--compiler-options', "'-fPIC'"]
+ cflags = COMMON_NVCC_FLAGS + ['--compiler-options', "'-fPIC'"] + cflags
elif isinstance(cflags, dict):
cflags = cflags['cxx']
# NVCC does not allow multiple -std to be passed, so we avoid
@@ -831,7 +837,7 @@
flags = ['cflags = {}'.format(' '.join(cflags))]
if with_cuda:
- cuda_flags = common_cflags
+ cuda_flags = common_cflags + COMMON_NVCC_FLAGS
if sys.platform == 'win32':
cuda_flags = _nt_quote_args(cuda_flags)
else:
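For reference, extension kernels that want to sidestep raw half operators
entirely can go through THC's explicit numerics traits instead; a minimal
sketch (hypothetical kernel, not part of this PR):

    #include <THC/THCNumerics.cuh>

    // THCNumerics<T>::lt is THC's explicit comparison entry point; it is
    // unambiguous for half regardless of the __CUDA_NO_HALF_* defines.
    template <typename T>
    __global__ void lt_kernel(const T* a, const T* b, bool* out) {
      out[0] = THCNumerics<T>::lt(a[0], b[0]);
    }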