Move THCTensor_(geometric) to ATen (#21298)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21298
ghimport-source-id: c0e2604aa25cc5da2b67293cafd88c2e77e476f9

Reviewed By: jerryzh168

Differential Revision: D15632932

Pulled By: ezyang

fbshipit-source-id: 248ca4b56967116f27174cda44893ecfe4ca9a99
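
For context, the ported CUDA kernel samples the geometric distribution by inverse-CDF transform of a curand uniform draw: `ceil(log(u) / log(1 - p))` (with the `__logf` fast-math approximation on the float path). Below is a minimal, standalone host-side sketch of that same transform, not part of this PR, using `std::mt19937` in place of the Philox generator purely for illustration:

```cpp
#include <cmath>
#include <cstdio>
#include <random>

int main() {
  const double p = 0.2;  // success probability; geometric_ requires p in (0, 1)
  std::mt19937 rng(42);
  // Draw u from (0, 1); curand_uniform likewise never returns 0, so log(u) is finite.
  std::uniform_real_distribution<double> uniform(std::nextafter(0.0, 1.0), 1.0);
  for (int i = 0; i < 5; ++i) {
    double u = uniform(rng);
    // Inverse-CDF transform: number of Bernoulli(p) trials up to and including
    // the first success, i.e. a sample on {1, 2, 3, ...}.
    double sample = std::ceil(std::log(u) / std::log(1.0 - p));
    std::printf("%g\n", sample);
  }
  return 0;
}
```

After this change, `Tensor::geometric_` on CUDA tensors dispatches to the new `geometric_cuda_` entry point instead of the legacy `_th_geometric_` THC path.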
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 2bd5acc..427901a 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2735,7 +2735,6 @@
   name: _th_geometric_
   backends:
     - CPU
-    - CUDA
   cname: geometric
   variants: function
   return: self
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index ddba24d..9dedce1 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -590,6 +590,32 @@
    });
 }
 
+void geometric_kernel_cuda(TensorIterator& iter, double p_, Generator* gen_) {
+  auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, iter.dtype(), "geometric_cuda", [&] {
+    if (std::is_same<scalar_t, double>::value) {
+      // define lambda for geometric transformation
+      auto geometric_func = [p_] __device__ (double rand) {
+        return static_cast<scalar_t>(::ceil(::log(rand) / ::log(static_cast<double>(1.0)-p_)));
+      };
+      distribution_nullary_kernel<scalar_t, double, curand4_engine_calls/2>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+        geometric_func);
+    } else {
+      auto p = static_cast<float>(p_);
+      auto geometric_func = [p] __device__ (float rand) {
+        // use __logf fast approximation for peak bandwidth
+        return static_cast<scalar_t>(::ceil(__logf(rand) / __logf(static_cast<float>(1.0)-p)));
+      };
+      distribution_nullary_kernel<scalar_t, float, curand4_engine_calls>(iter,
+        gen,
+        [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+        geometric_func);
+    }
+  });
+}
+
 Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
   auto iter = TensorIterator::nullary_op(self);
   uniform_kernel_cuda(*iter, from, to, gen);
@@ -681,4 +707,11 @@
   return self;
 }
 
+Tensor& geometric_cuda_(Tensor& self, double p, Generator* gen) {
+  TORCH_CHECK(0 < p && p < 1, "geometric_ expects p to be in (0, 1), but got p=", p);
+  auto iter = TensorIterator::nullary_op(self);
+  geometric_kernel_cuda(*iter, p, gen);
+  return self;
+}
+
 }} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2462eff..495f547 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3243,7 +3243,7 @@
   variants: method
   dispatch:
     CPU: legacy::cpu::_th_geometric_
-    CUDA: legacy::cuda::_th_geometric_
+    CUDA: geometric_cuda_
 
 # wrappers for TH functions
 
diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu
index bd008dd..a67ffd26 100644
--- a/aten/src/THC/THCTensorRandom.cu
+++ b/aten/src/THC/THCTensorRandom.cu
@@ -101,20 +101,6 @@
   }
 }
 
-#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM)      \
-__global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1)    \
-{                                                                              \
-  int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x;                             \
-  int rounded_size = THCCeilDiv(size, BLOCK_SIZE) * BLOCK_SIZE;                \
-  for (int i = idx; i < rounded_size; i += BLOCK_SIZE * MAX_NUM_BLOCKS) {      \
-    CURAND_T x = CURAND_FUNC(&state[blockIdx.x]);                              \
-    if (i < size) {                                                            \
-      T y = TRANSFORM;                                                         \
-      result[i] = y;                                                           \
-    }                                                                          \
-  }                                                                            \
-}
-
 #define GENERATE_KERNEL2(NAME, T, ARG1, ARG2, CURAND_T, CURAND_FUNC, TRANSFORM)      \
 __global__ void NAME(curandStateMtgp32 *state, int size, T *result, ARG1, ARG2)    \
 {                                                                                    \
@@ -135,5 +121,4 @@
 #include <THC/generic/THCTensorRandom.cu>
 #include <THC/THCGenerateBoolType.h>
 
-#undef GENERATE_KERNEL1
 #undef GENERATE_KERNEL2
diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu
index 81f71a4..9cd880b 100644
--- a/aten/src/THC/generic/THCTensorRandom.cu
+++ b/aten/src/THC/generic/THCTensorRandom.cu
@@ -319,28 +319,6 @@
 
 #endif
 
-#if defined(THC_REAL_IS_DOUBLE)
-GENERATE_KERNEL1(generate_geometric, double, double p, double, curand_uniform_double, ceil(log(x) / log(1-p)))
-#else
-GENERATE_KERNEL1(generate_geometric, scalar_t, double p, float, curand_uniform, (ScalarConvert<float, scalar_t>::to(ceilf(logf(x) / log(1-p)))))
-#endif
-
-void THCTensor_(geometric)(THCState* state, THCTensor *self_, double p)
-{
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
-  ptrdiff_t size = THCTensor_(nElement)(state, self_);
-  if (size == 0) return;
-  THCGenerator* gen = THCRandom_getGenerator(state);
-
-  THCTensor *self = THCTensor_(newContiguous)(state, self_);
-  scalar_t *data = THCTensor_(data)(state, self);
-
-  generate_geometric<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
-      gen->state.gen_states, size, data, p);
-
-  THCTensor_(freeCopyTo)(state, self, self_);
-};
-
 #undef NUM_BLOCKS
 
 #endif
diff --git a/aten/src/THC/generic/THCTensorRandom.h b/aten/src/THC/generic/THCTensorRandom.h
index 5569c86..f513b8d 100644
--- a/aten/src/THC/generic/THCTensorRandom.h
+++ b/aten/src/THC/generic/THCTensorRandom.h
@@ -10,7 +10,4 @@
 THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample);
 
 #endif
-
-THC_API void THCTensor_(geometric)(struct THCState *state, THCTensor *self, double p);
-
 #endif