Move THCTensor_(cauchy) to ATen (#21289)
Summary:
## Effective Bandwidth Benchmark
- using https://gist.github.com/syed-ahmed/f8b7384d642f4bce484228b508b4bc68
- on V100
### Float Type
#### Before:
```
cauchy, size, elements 65536 forward 4.980564117431641e-06 bandwidth (GB/s) 52.63339529803734
cauchy, size, elements 131072 forward 6.232261657714844e-06 bandwidth (GB/s) 84.12483762631982
cauchy, size, elements 262144 forward 9.548664093017577e-06 bandwidth (GB/s) 109.81389540833959
cauchy, size, elements 524288 forward 1.59454345703125e-05 bandwidth (GB/s) 131.52052963827754
cauchy, size, elements 1048576 forward 2.86865234375e-05 bandwidth (GB/s) 146.21165262978724
cauchy, size, elements 2097152 forward 5.4748058319091796e-05 bandwidth (GB/s) 153.2220184158516
cauchy, size, elements 4194304 forward 0.00010075807571411133 bandwidth (GB/s) 166.50988897012377
cauchy, size, elements 8388608 forward 0.0001935744285583496 bandwidth (GB/s) 173.34124269355965
cauchy, size, elements 16777216 forward 0.00038077831268310545 bandwidth (GB/s) 176.24129779641603
cauchy, size, elements 33554432 forward 0.0006851387023925781 bandwidth (GB/s) 195.8986224705994
```
#### After:
```
cauchy, size, elements 65536 forward 6.077289581298828e-06 bandwidth (GB/s) 43.13501874366419
cauchy, size, elements 131072 forward 6.2131881713867184e-06 bandwidth (GB/s) 84.38308731972373
cauchy, size, elements 262144 forward 6.46829605102539e-06 bandwidth (GB/s) 162.11008150033175
cauchy, size, elements 524288 forward 6.8783760070800785e-06 bandwidth (GB/s) 304.8905726935182
cauchy, size, elements 1048576 forward 9.505748748779296e-06 bandwidth (GB/s) 441.23867681003264
cauchy, size, elements 2097152 forward 1.5070438385009766e-05 bandwidth (GB/s) 556.6266744001266
cauchy, size, elements 4194304 forward 2.4406909942626954e-05 bandwidth (GB/s) 687.396152951685
cauchy, size, elements 8388608 forward 4.6243667602539064e-05 bandwidth (GB/s) 725.6005792706125
cauchy, size, elements 16777216 forward 9.100198745727539e-05 bandwidth (GB/s) 737.4439380404413
cauchy, size, elements 33554432 forward 0.00017449140548706055 bandwidth (GB/s) 769.1939188944922
```
### Double Type
#### Before:
```
cauchy, size, elements 65536 forward 4.885196685791015e-06 bandwidth (GB/s) 53.660889593753055
cauchy, size, elements 131072 forward 6.229877471923828e-06 bandwidth (GB/s) 84.15703235943361
cauchy, size, elements 262144 forward 9.605884552001953e-06 bandwidth (GB/s) 109.15975455706132
cauchy, size, elements 524288 forward 1.5976428985595704e-05 bandwidth (GB/s) 131.26537863315923
cauchy, size, elements 1048576 forward 2.9621124267578124e-05 bandwidth (GB/s) 141.59840666786866
cauchy, size, elements 2097152 forward 5.5103302001953126e-05 bandwidth (GB/s) 152.23421637604707
cauchy, size, elements 4194304 forward 0.00010124444961547851 bandwidth (GB/s) 165.70998275677383
cauchy, size, elements 8388608 forward 0.0001944279670715332 bandwidth (GB/s) 172.58027487195184
cauchy, size, elements 16777216 forward 0.00034950494766235353 bandwidth (GB/s) 192.01119883668116
cauchy, size, elements 33554432 forward 0.0007002186775207519 bandwidth (GB/s) 191.67973135938277
```
#### After:
```
cauchy, size, elements 65536 forward 5.91278076171875e-06 bandwidth (GB/s) 44.33514628129032
cauchy, size, elements 131072 forward 6.234645843505859e-06 bandwidth (GB/s) 84.09266751632889
cauchy, size, elements 262144 forward 7.433891296386719e-06 bandwidth (GB/s) 141.05344807902503
cauchy, size, elements 524288 forward 1.1401176452636719e-05 bandwidth (GB/s) 183.94171941045587
cauchy, size, elements 1048576 forward 1.960039138793945e-05 bandwidth (GB/s) 213.99082890665372
cauchy, size, elements 2097152 forward 3.434181213378906e-05 bandwidth (GB/s) 244.26806504326578
cauchy, size, elements 4194304 forward 6.517410278320313e-05 bandwidth (GB/s) 257.4215107465028
cauchy, size, elements 8388608 forward 0.0001229524612426758 bandwidth (GB/s) 272.9057365819818
cauchy, size, elements 16777216 forward 0.00023239374160766602 bandwidth (GB/s) 288.77225150621814
cauchy, size, elements 33554432 forward 0.00046050310134887696 bandwidth (GB/s) 291.4589013773367
```
Resubmit of https://github.com/pytorch/pytorch/pull/20622
Pull Request resolved: https://github.com/pytorch/pytorch/pull/21289
Differential Revision: D15622713
Pulled By: ezyang
fbshipit-source-id: abe8bd57794bd1c3a0b92395367a9653c5d0f2db
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index bba874a..88ca447 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -2681,7 +2681,6 @@
- floating_point
backends:
- CPU
- - CUDA
cname: cauchy
variants: function
return: self
diff --git a/aten/src/ATen/native/cuda/Distributions.cu b/aten/src/ATen/native/cuda/Distributions.cu
index b482a68..5fd2e95 100644
--- a/aten/src/ATen/native/cuda/Distributions.cu
+++ b/aten/src/ATen/native/cuda/Distributions.cu
@@ -516,6 +516,36 @@
});
}
+void cauchy_kernel_cuda(TensorIterator& iter, double median_, double sigma_, Generator* gen_) {
+ auto gen = check_generator<CUDAGenerator>(gen_, &globalContext().defaultGenerator(kCUDA));
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "cauchy_cuda", [&] {
+ using accscalar_t = at::acc_type<scalar_t, true>;
+ auto median = static_cast<accscalar_t>(median_);
+ auto sigma = static_cast<accscalar_t>(sigma_);
+ if (std::is_same<scalar_t, double>::value) {
+ // define lambda for cauchy transformation
+ auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
+ return static_cast<scalar_t>(median + sigma *
+ ::tan(static_cast<accscalar_t>(M_PI) * (rand-static_cast<accscalar_t>(0.5))));
+ };
+ distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls/2>(iter,
+ gen,
+ [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform2_double(state); },
+ cauchy_func);
+ } else {
+ // use __tanf fast approximation for peak bandwidth
+ auto cauchy_func = [median, sigma] __device__ (accscalar_t rand) {
+ return static_cast<scalar_t>(median + sigma *
+ __tanf(static_cast<accscalar_t>(M_PI) * (rand-static_cast<accscalar_t>(0.5))));
+ };
+ distribution_nullary_kernel<scalar_t, accscalar_t, curand4_engine_calls>(iter,
+ gen,
+ [] __device__ (curandStatePhilox4_32_10_t* state) { return curand_uniform4(state); },
+ cauchy_func);
+ }
+ });
+}
+
Tensor& uniform_cuda_(Tensor& self, double from, double to, Generator* gen) {
auto iter = TensorIterator::nullary_op(self);
uniform_kernel_cuda(*iter, from, to, gen);
@@ -595,4 +625,10 @@
return ret;
}
+Tensor& cauchy_cuda_(Tensor& self, double median, double sigma, Generator* gen) {
+ auto iter = TensorIterator::nullary_op(self);
+ cauchy_kernel_cuda(*iter, median, sigma, gen);
+ return self;
+}
+
}} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 75bb3f6..7ba55a6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3214,7 +3214,7 @@
variants: method
dispatch:
CPU: legacy::cpu::_th_cauchy_
- CUDA: legacy::cuda::_th_cauchy_
+ CUDA: cauchy_cuda_
- func: log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)
variants: method
diff --git a/aten/src/THC/THCTensorRandom.cu b/aten/src/THC/THCTensorRandom.cu
index 0817d34..53a9103 100644
--- a/aten/src/THC/THCTensorRandom.cu
+++ b/aten/src/THC/THCTensorRandom.cu
@@ -132,11 +132,7 @@
GENERATE_KERNEL1(generate_exponential, float, double lambda, float, curand_uniform, (float)(-1. / lambda * log(x)))
GENERATE_KERNEL1(generate_exponential, double, double lambda, double, curand_uniform_double, (double)(-1. / lambda * log(x)))
-GENERATE_KERNEL2(generate_cauchy, float, double median, double sigma, float, curand_uniform, (float)(median + sigma * tan(M_PI*(x-0.5))))
-GENERATE_KERNEL2(generate_cauchy, double, double median, double sigma, double, curand_uniform_double, (double)(median + sigma * tan(M_PI*(x-0.5))))
-
GENERATE_KERNEL1(generate_exponential, at::Half, double lambda, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(-1. / lambda * log(x)))))
-GENERATE_KERNEL2(generate_cauchy, at::Half, double median, double sigma, float, curand_uniform, (ScalarConvert<float, at::Half>::to((float)(median + sigma * tan(M_PI*(x-0.5))))))
#include <THC/generic/THCTensorRandom.cu>
#include <THC/THCGenerateAllTypes.h>
diff --git a/aten/src/THC/generic/THCTensorRandom.cu b/aten/src/THC/generic/THCTensorRandom.cu
index c0e5cce..af4bca3 100644
--- a/aten/src/THC/generic/THCTensorRandom.cu
+++ b/aten/src/THC/generic/THCTensorRandom.cu
@@ -42,22 +42,6 @@
THCTensor_(freeCopyTo)(state, self, self_);
};
-void THCTensor_(cauchy)(THCState* state, THCTensor *self_, double median, double sigma)
-{
- THCAssertSameGPU(THCTensor_(checkGPU)(state, 1, self_));
- ptrdiff_t size = THCTensor_(nElement)(state, self_);
- if (size == 0) return;
- THCGenerator* gen = THCRandom_getGenerator(state);
-
- THCTensor *self = THCTensor_(newContiguous)(state, self_);
- scalar_t *data = THCTensor_(data)(state, self);
-
- generate_cauchy<<<NUM_BLOCKS, BLOCK_SIZE, 0, THCState_getCurrentStream(state)>>>(
- gen->state.gen_states, size, data, median, sigma);
-
- THCTensor_(freeCopyTo)(state, self, self_);
-};
-
void THCTensor_(renormRows)(struct THCState* state,
THCTensor* t) {
THAssert(THCTensor_(nDimensionLegacyAll)(state, t) == 2);
diff --git a/aten/src/THC/generic/THCTensorRandom.h b/aten/src/THC/generic/THCTensorRandom.h
index 20038b9..90d0760 100644
--- a/aten/src/THC/generic/THCTensorRandom.h
+++ b/aten/src/THC/generic/THCTensorRandom.h
@@ -6,7 +6,6 @@
THC_API void THCTensor_(logNormal)(struct THCState *state, THCTensor *self, double mean, double stdv);
THC_API void THCTensor_(exponential)(struct THCState *state, THCTensor *self, double lambda);
-THC_API void THCTensor_(cauchy)(struct THCState *state, THCTensor *self, double median, double sigma);
THC_API void THCTensor_(multinomial)(struct THCState *state, THCudaLongTensor *self, THCTensor *prob_dist, int n_sample, int with_replacement);
THC_API void THCTensor_(multinomialAliasSetup)(struct THCState *state, THCTensor *probs, THCudaLongTensor *J, THCTensor *q);
THC_API void THCTensor_(multinomialAliasDraw)(THCState *state, THCudaLongTensor *self, THCTensor *_q, THCudaLongTensor *_J, int n_sample);