blob: a3621bfe7c55fbd7b367d793609fba8c1c9574a2 [file] [log] [blame]
#include <torch/torch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <THC/THCNumerics.cuh>
// Comparison probe kernel. It evaluates `input[0] < input[1]` and, if that is
// false, `input[0] >= input[1]` on the element type T. If either comparison
// holds, it stores the sentinel value 123 (converted to U) into output[0].
// When both comparisons are false, for example an unordered pair involving
// NaN for floating types, output[0] is left untouched.
//
// Preconditions (not checked here): `input` must point to at least two
// readable device elements and `output` to at least one writable device
// element. The body never reads threadIdx/blockIdx, so every launched thread
// performs the same reads and the same (possible) write.
template <typename T, typename U>
__global__ void half_test_kernel(const T* input, U* output) {
  const T& lhs = input[0];
  const T& rhs = input[1];
  // Bail out when the pair is unordered. !(a || b) is rewritten as !a && !b,
  // which keeps the original short-circuit order: `>=` is only evaluated when
  // `<` is false.
  if (!(lhs < rhs) && !(lhs >= rhs)) {
    return;
  }
  output[0] = 123;
}
// Host entry point for the half-precision comparison check.
//
// Dispatches on input's floating dtype (the AT_DISPATCH_FLOATING_TYPES_AND_HALF
// set) and launches half_test_kernel with a single thread on the current CUDA
// stream of input's device. Returns a one-element float tensor on that device.
// The element is 123 when input[0] and input[1] compare as ordered (either `<`
// or `>=` holds) and 0 otherwise.
//
// Throws (via TORCH_CHECK) if input is not a CUDA tensor, has fewer than two
// elements, or if the kernel launch reports an error.
at::Tensor half_test(at::Tensor input) {
  TORCH_CHECK(input.is_cuda(), "half_test: expected a CUDA tensor");
  TORCH_CHECK(input.numel() >= 2,
              "half_test: expected at least 2 elements, got ", input.numel());

  // Make input's device current so the stream and launch below target the
  // GPU that actually owns the data.
  const c10::cuda::CUDAGuard device_guard(input.device());

  // The kernel indexes the raw data pointer as [0] and [1]. That only
  // addresses the first two logical elements when storage is contiguous.
  const at::Tensor contiguous_input = input.contiguous();

  // Zero-initialise the result. If the kernel skips its write (unordered
  // comparison), the caller then gets a defined value rather than whatever
  // at::empty would have left in the buffer.
  auto output = at::zeros({1}, contiguous_input.options().dtype(at::kFloat));

  // Launch on PyTorch's current stream. This orders the kernel after the
  // zero-fill above and after whatever work produced `input`.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      contiguous_input.scalar_type(), "half_test", [&] {
        half_test_kernel<scalar_t><<<1, 1, 0, stream>>>(
            contiguous_input.data_ptr<scalar_t>(), output.data_ptr<float>());
      });

  // Kernel launches do not return errors directly. Surface bad-launch errors
  // here instead of at some unrelated later CUDA call.
  const cudaError_t launch_error = cudaGetLastError();
  TORCH_CHECK(launch_error == cudaSuccess,
              "half_test: kernel launch failed: ",
              cudaGetErrorString(launch_error));
  return output;
}