// blob: c96b7f3929bc8939a74b2440d98d9681a8e94964 [file] [log] [blame]
#include <ATen/Dispatch.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Loops.cuh>
// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.
namespace at { namespace native {
// Binary comparison functor: evaluates `lhs >= rhs` for two values of type
// scalar_t and returns the outcome as a bool. The call operator is marked
// __device__, so it is only callable from device code.
template <typename scalar_t>
struct CompareGEFunctor {
  __device__ __forceinline__ bool operator()(scalar_t lhs, scalar_t rhs) const {
    // Keep the comparison as a single `>=` so the result is whatever
    // scalar_t's own operator yields.
    const bool is_ge = lhs >= rhs;
    return is_ge;
  }
};
// Runs the elementwise `>=` comparison described by `iter`.
//
// The iterator's common dtype selects scalar_t through
// AT_DISPATCH_ALL_TYPES_AND3 (with kHalf, kBFloat16 and kBool as the three
// extra types); a CompareGEFunctor<scalar_t> is then handed to
// gpu_kernel_with_scalars together with `iter`.
// NOTE(review): gpu_kernel_with_scalars comes from Loops.cuh and presumably
// also covers the case where one operand is a scalar — confirm.
void ge_kernel_cuda(TensorIterator& iter) {
  const auto compare_dtype = iter.common_dtype();
  AT_DISPATCH_ALL_TYPES_AND3(
      kHalf, kBFloat16, kBool, compare_dtype, "ge_cuda", [&]() {
        using ge_functor_t = CompareGEFunctor<scalar_t>;
        gpu_kernel_with_scalars(iter, ge_functor_t());
      });
}
// Register ge_kernel_cuda as the kernel behind ge_stub. ge_stub is declared
// outside this file (presumably in ATen/native/BinaryOps.h — confirm).
REGISTER_DISPATCH(ge_stub, &ge_kernel_cuda);
}} // namespace at::native