#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCNumerics.cuh>
// Smoke-test kernel for the `<` and `>=` comparison operators on T
// (e.g. at::Half).
//
// Reads input[0] and input[1] and writes exactly one value to output[0]:
//   123 if `input[0] < input[1] || input[0] >= input[1]` evaluates true,
//   0   otherwise. For IEEE floating types the condition is false only when
//       the comparison is unordered, i.e. an operand is NaN.
// output[0] is written on both paths, so the caller never observes
// uninitialized memory.
//
// Launch: intended for a single thread (<<<1, 1>>>). Any other thread returns
// immediately, so output[0] has exactly one writer whatever the launch shape.
// Preconditions: `input` points to at least 2 device-accessible elements and
// `output` to at least 1.
template <typename T, typename U>
__global__ void half_test_kernel(const T* input, U* output) {
  if (blockIdx.x != 0 || threadIdx.x != 0) {
    return;
  }
  const bool comparable = input[0] < input[1] || input[0] >= input[1];
  output[0] = comparable ? static_cast<U>(123) : static_cast<U>(0);
}
// Host entry point. Runs half_test_kernel once on the first two elements of
// `input` and returns a 1-element float tensor. The value is 123 when
// `input[0] < input[1] || input[0] >= input[1]` holds, and 0 otherwise.
//
// Preconditions (checked): `input` is a CUDA tensor with at least 2 elements
// of a type covered by AT_DISPATCH_FLOATING_TYPES_AND_HALF.
// NOTE(review): the kernel runs on the current CUDA device/stream. This
// assumes the caller made `input`'s device current. TODO: confirm, or add a
// device guard.
at::Tensor half_test(at::Tensor input) {
  TORCH_CHECK(input.is_cuda(), "half_test: input must be a CUDA tensor");
  TORCH_CHECK(input.numel() >= 2,
              "half_test: input must have at least 2 elements, got ",
              input.numel());

  // The kernel indexes raw memory as input[0] and input[1]. Force a
  // contiguous layout so those are the first two logical elements.
  at::Tensor in = input.contiguous();

  // Zero-initialize so the result is well defined even when the kernel's
  // condition is false (e.g. NaN inputs).
  at::Tensor output = at::zeros({1}, in.options().dtype(at::kFloat));

  // Launch on PyTorch's current stream so the kernel is ordered with other
  // work queued there. The legacy default stream is not necessarily current.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(in.type(), "half_test", [&] {
    half_test_kernel<scalar_t><<<1, 1, 0, stream>>>(
        in.data<scalar_t>(), output.data<float>());
  });

  // Kernel launches do not return errors, so surface launch failures here.
  const cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "half_test: kernel launch failed: ",
              cudaGetErrorString(err));
  return output;
}