Move the CUDA implementation of log2 to ATen. (#26769)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26769
Fix #24589
Test Plan: Imported from OSS
Differential Revision: D17960122
Pulled By: VitalyFedyunin
fbshipit-source-id: 58dff236886bbf3a0a152d7422aa8a5c478ee1de
diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap
index 4da3e87..0c5da08 100644
--- a/aten/src/ATen/Declarations.cwrap
+++ b/aten/src/ATen/Declarations.cwrap
@@ -777,20 +777,6 @@
- THTensor* self
]]
[[
- name: _th_log2
- cname: log2
- types:
- - floating_point
- backends:
- - CUDA
- variants: function
- return: argument 0
- arguments:
- - arg: THTensor* result
- output: True
- - THTensor* self
-]]
-[[
name: _th_exp
cname: exp
types:
diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp
index 3988b5b..90051bb 100644
--- a/aten/src/ATen/native/UnaryOps.cpp
+++ b/aten/src/ATen/native/UnaryOps.cpp
@@ -81,6 +81,10 @@
Tensor log10(const Tensor& self) { return unary_op_impl(self, at::log10_out); }
Tensor& log10_(Tensor& self) { return unary_op_impl_(self, at::log10_out); }
+Tensor& log2_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log2_stub); }  // out= variant: forwards result/self to unary_op_impl_out together with log2_stub
+Tensor log2(const Tensor& self) { return unary_op_impl(self, at::log2_out); }  // value-returning variant: delegates to unary_op_impl, handing it at::log2_out as the out-function
+Tensor& log2_(Tensor& self) { return unary_op_impl_(self, at::log2_out); }  // trailing-underscore variant (presumably in-place, per naming convention): delegates to unary_op_impl_ with at::log2_out
+
Tensor& round_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, round_stub); }
Tensor round(const Tensor& self) { return unary_op_impl(self, at::round_out); }
Tensor& round_(Tensor& self) { return unary_op_impl_(self, at::round_out); }
@@ -279,7 +283,6 @@
IMPLEMENT_UNARY_OP_VEC(exp)
IMPLEMENT_UNARY_OP_VEC(frac)
IMPLEMENT_UNARY_OP_VEC(log1p)
-IMPLEMENT_UNARY_OP_VEC(log2)
IMPLEMENT_UNARY_OP_VEC(reciprocal)
IMPLEMENT_UNARY_OP_VEC(sigmoid)
IMPLEMENT_UNARY_OP_VEC(sin)
diff --git a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
index 6193083..3bbefb5 100644
--- a/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
+++ b/aten/src/ATen/native/cuda/CUDAUnaryOps.cpp
@@ -75,7 +75,6 @@
IMPLEMENT_UNARY_OP_PREQUEL(exp)
IMPLEMENT_UNARY_OP_PREQUEL(frac)
IMPLEMENT_UNARY_OP_PREQUEL(log1p)
-IMPLEMENT_UNARY_OP_PREQUEL(log2)
IMPLEMENT_UNARY_OP_PREQUEL(reciprocal)
IMPLEMENT_UNARY_OP_PREQUEL(sigmoid)
IMPLEMENT_UNARY_OP_PREQUEL(sin)
diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
index 54b922f..9e5c4df 100644
--- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
+++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu
@@ -73,6 +73,14 @@
});
}
+void log2_kernel_cuda(TensorIterator& iter) {  // CUDA log2 entry point; hooked up to log2_stub by the REGISTER_DISPATCH line added in this file
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(iter.dtype(), "log2_cuda", [&]() {  // dispatches on iter.dtype(); presumably covers float/double/half per the macro name, "log2_cuda" is the label passed to it
+ gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {  // capture-free lambda given to gpu_kernel: one scalar_t in, one scalar_t out
+ return ::log2(a);  // global-namespace log2 on the single input value; result is returned as scalar_t
+ });
+ });
+}
+
void neg_kernel_cuda(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND(ScalarType::Half, iter.dtype(), "neg_cuda", [&]() {
gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
@@ -189,6 +197,7 @@
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(log_stub, &log_kernel_cuda);
REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda);
+REGISTER_DISPATCH(log2_stub, &log2_kernel_cuda);
REGISTER_DISPATCH(neg_stub, &neg_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2fd7a81..1100383 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1637,15 +1637,12 @@
use_c10_dispatcher: unboxed_only
supports_named_tensor: True
variants: function, method
- dispatch:
- CPU: _log2__cpu
- CUDA: _log2__cuda
- func: log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
supports_named_tensor: True
dispatch:
- CPU: _log2_out_cpu
- CUDA: _log2_out_cuda
+ CPU: log2_out
+ CUDA: log2_out
- func: logdet(Tensor self) -> Tensor
use_c10_dispatcher: full
diff --git a/aten/src/TH/THGeneral.cpp b/aten/src/TH/THGeneral.cpp
index 060b1e6..18eebe9 100644
--- a/aten/src/TH/THGeneral.cpp
+++ b/aten/src/TH/THGeneral.cpp
@@ -198,11 +198,6 @@
#endif
}
-double THLog2(const double x)
-{
- return log2(x);
-}
-
THDescBuff _THSizeDesc(const int64_t *size, const int64_t ndim) {
const int L = TH_DESC_BUFF_LEN;
THDescBuff buf;
diff --git a/aten/src/TH/THGeneral.h.in b/aten/src/TH/THGeneral.h.in
index c50600d..09642b6 100644
--- a/aten/src/TH/THGeneral.h.in
+++ b/aten/src/TH/THGeneral.h.in
@@ -89,7 +89,6 @@
TH_API double THLog1p(const double x);
-TH_API double THLog2(const double x);
TH_API THDescBuff _THSizeDesc(const int64_t *size, const int64_t ndim);
TH_API TH_NO_RETURN void _THError(const char *file, const int line, const char *fmt, ...);
TH_API void _THAssertionFailed(const char *file, const int line, const char *exp, const char *fmt, ...);
@@ -168,12 +167,6 @@
inline double log1p(double x) { return THLog1p(x); }
#endif
-#if defined(_MSC_VER)
-__inline double log2(double x) { return THLog2(x); }
-#else
-inline double log2(double x) { return THLog2(x); }
-#endif
-
#define snprintf _snprintf
#define popen _popen
#define pclose _pclose
diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h
index 97b88cf..9ab7721 100644
--- a/aten/src/TH/generic/THTensorMath.h
+++ b/aten/src/TH/generic/THTensorMath.h
@@ -152,7 +152,6 @@
TH_API void THTensor_(sigmoid)(THTensor *r_, THTensor *t);
TH_API void THTensor_(log1p)(THTensor *r_, THTensor *t);
-TH_API void THTensor_(log2)(THTensor *r_, THTensor *t);
TH_API void THTensor_(exp)(THTensor *r_, THTensor *t);
TH_API void THTensor_(cos)(THTensor *r_, THTensor *t);
TH_API void THTensor_(acos)(THTensor *r_, THTensor *t);
diff --git a/aten/src/TH/generic/THVector.h b/aten/src/TH/generic/THVector.h
index d6fffd7..695f8e7 100644
--- a/aten/src/TH/generic/THVector.h
+++ b/aten/src/TH/generic/THVector.h
@@ -32,7 +32,6 @@
#if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE)
TH_API void THVector_(log1p)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
-TH_API void THVector_(log2)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(sigmoid)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(exp)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
TH_API void THVector_(erf)(scalar_t *y, const scalar_t *x, const ptrdiff_t n);
diff --git a/aten/src/TH/generic/THVectorDefault.cpp b/aten/src/TH/generic/THVectorDefault.cpp
index e847cda..ceb89a8 100644
--- a/aten/src/TH/generic/THVectorDefault.cpp
+++ b/aten/src/TH/generic/THVectorDefault.cpp
@@ -244,7 +244,6 @@
#endif
VECTOR_IMPLEMENT_FUNCTION(log1p,TH_MATH_NAME(log1p))
-VECTOR_IMPLEMENT_FUNCTION(log2,TH_MATH_NAME(log2))
VECTOR_IMPLEMENT_FUNCTION(sigmoid_DEFAULT,TH_MATH_NAME(TH_sigmoid))
VECTOR_IMPLEMENT_FUNCTION(exp,TH_MATH_NAME(exp))
VECTOR_IMPLEMENT_FUNCTION(erf,TH_MATH_NAME(erf))
diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in
index fac4e04..e311c15 100644
--- a/aten/src/THC/THCGeneral.h.in
+++ b/aten/src/THC/THCGeneral.h.in
@@ -4,7 +4,6 @@
#include <TH/THGeneral.h>
#include <TH/THAllocator.h>
#undef log1p
-#undef log2
#include <c10/cuda/CUDAStream.h>
diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh
index 7fb0416..8f10976 100644
--- a/aten/src/THC/THCNumerics.cuh
+++ b/aten/src/THC/THCNumerics.cuh
@@ -205,7 +205,6 @@
static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); }
static inline __host__ __device__ at::Half exp10(at::Half a) { return ::exp10(a); }
static inline __host__ __device__ at::Half log1p(at::Half a) { return ::log1p(a); }
- static inline __host__ __device__ at::Half log2(at::Half a) { return ::log2(a); }
static inline __host__ __device__ at::Half cos(at::Half a) { return ::cos(a); }
static inline __host__ __device__ at::Half sin(at::Half a) { return ::sin(a); }
static inline __host__ __device__ at::Half sqrt(at::Half a) { return ::sqrt(a); }
@@ -278,7 +277,6 @@
static inline __host__ __device__ float exp (float a) { return expf(a); }
static inline __host__ __device__ float exp10(float a) { return exp10f(a); }
static inline __host__ __device__ float log1p(float a) { return log1pf(a); }
- static inline __host__ __device__ float log2 (float a) { return log2f(a); }
static inline __host__ __device__ float cos (float a) { return cosf(a); }
static inline __host__ __device__ float sin (float a) { return sinf(a); }
static inline __host__ __device__ float sqrt (float a) { return sqrtf(a); }
@@ -326,7 +324,6 @@
static inline __host__ __device__ double exp (double a) { return ::exp(a); }
static inline __host__ __device__ double exp10(double a) { return ::exp10(a); }
static inline __host__ __device__ double log1p(double a) { return ::log1p(a); }
- static inline __host__ __device__ double log2 (double a) { return ::log2(a); }
static inline __host__ __device__ double cos (double a) { return ::cos(a); }
static inline __host__ __device__ double sin (double a) { return ::sin(a); }
static inline __host__ __device__ double sqrt (double a) { return ::sqrt(a); }
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.cu b/aten/src/THC/generic/THCTensorMathPointwise.cu
index 8f84d4e..1283bbd 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPointwise.cu
@@ -199,7 +199,6 @@
#if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE) || defined(THC_REAL_IS_HALF)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, THCNumerics<scalar_t>::log1p, Real)
-IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( log2, THCNumerics<scalar_t>::log2, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( exp, THCNumerics<scalar_t>::exp, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( cos, THCNumerics<scalar_t>::cos, Real)
IMPLEMENT_CUDA_TENSOR_BASIC_FUNC( sin, THCNumerics<scalar_t>::sin, Real)
diff --git a/aten/src/THC/generic/THCTensorMathPointwise.h b/aten/src/THC/generic/THCTensorMathPointwise.h
index 3f688ef..b1e2639 100644
--- a/aten/src/THC/generic/THCTensorMathPointwise.h
+++ b/aten/src/THC/generic/THCTensorMathPointwise.h
@@ -17,7 +17,6 @@
THC_API void THCTensor_(sigmoid)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(log1p)(THCState *state, THCTensor *self, THCTensor *src);
-THC_API void THCTensor_(log2)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(exp)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(cos)(THCState *state, THCTensor *self, THCTensor *src);
THC_API void THCTensor_(acos)(THCState *state, THCTensor *self, THCTensor *src);
diff --git a/aten/src/THCUNN/THCHalfAutoNumerics.cuh b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
index 9d97131..fdd8328 100644
--- a/aten/src/THCUNN/THCHalfAutoNumerics.cuh
+++ b/aten/src/THCUNN/THCHalfAutoNumerics.cuh
@@ -43,10 +43,6 @@
return THCNumerics<THHalf>::log1p(a);
}
-inline __host__ __device__ THHalf log2(THHalf a) {
- return THCNumerics<THHalf>::log2(a);
-}
-
inline __host__ __device__ THHalf pow(THHalf a, THHalf b) {
return THCNumerics<THHalf>::pow(a, b);
}
diff --git a/test/test_torch.py b/test/test_torch.py
index 2c6ac3e..35ce7c7 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -10945,7 +10945,7 @@
("log1p", positives, True, True, 'cpu'),
("log1p", positives, False, True, 'cuda'),
("log2", positives, True, True, 'cpu'),
- ("log2", positives, False, True, 'cuda'),
+ ("log2", positives, True, True, 'cuda'),
("neg", doubles, True, True, 'cpu'),
("neg", doubles, True, True, 'cuda'),
("reciprocal", doubles, True, True, 'cpu'),